llm.py 65.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
8
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
9

10
import cloudpickle
11
import torch.nn as nn
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar, deprecated
14

15
16
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
17
18
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
19
20
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
21
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
22
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
23
                                         ChatTemplateContentFormatOption,
24
25
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
26
27
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
28
29
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
30
from vllm.entrypoints.utils import _validate_truncation_size
31
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
32
from vllm.inputs.parse import parse_and_batch_prompt
33
from vllm.logger import init_logger
34
from vllm.lora.request import LoRARequest
35
36
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
37
from vllm.model_executor.layers.quantization import QuantizationMethods
38
39
40
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
41
from vllm.pooling_params import PoolingParams
42
from vllm.prompt_adapter.request import PromptAdapterRequest
43
44
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
45
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
46
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
47
from vllm.usage.usage_lib import UsageContext
48
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
49

50
51
52
# `Metric` is only referenced in annotations, so it is imported for static
# type checkers and never at runtime.
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

# Result type of the per-worker callables used by `collective_rpc` and
# `apply_model`; the TypeVar default is `Any`.
_R = TypeVar("_R", default=Any)

logger = init_logger(__name__)

57
58

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
59
60
61
62
63
64
65
66
67
68
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
69
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
70
71
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
72
73
74
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
75
76
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
77
78
79
80
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
81
82
83
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
84
85
86
87
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
88
        quantization: The method used to quantize the model weights. Currently,
89
            we support "awq", "gptq", and "fp8" (experimental).
90
91
92
93
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
94
95
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
96
97
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
98
99
100
101
102
103
104
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
105
106
107
108
109
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Noting that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
110
111
112
113
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
114
115
116
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
117
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
118
            When a sequence has context length larger than this, we fall back
119
120
121
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
122
123
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
124
125
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
126
        hf_token: The token to use as HTTP bearer authorization for remote files
127
            . If `True`, will use the token generated when running
128
            `huggingface-cli login` (stored in `~/.huggingface`).
129
130
131
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
132
133
134
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
135
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
136

137
138
    Note:
        This class is intended to be used for offline inference. For online
139
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
140
    """
141

142
    # Off by default: `deprecate_kwargs` on `generate` only treats
    # `prompt_token_ids` as deprecated while this flag is True. Use the
    # `deprecate_legacy_api()` context manager to switch it on temporarily.
    # (A default of True made that context manager pointless and flagged
    # every legacy call as deprecated.)
    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily treat the legacy generate/encode API as deprecated.

        Inside the ``with`` block ``cls.DEPRECATE_LEGACY`` is True, so the
        ``deprecate_kwargs`` check on ``generate`` (``prompt_token_ids``)
        is active. On exit, including exit by exception, the flag is
        restored to the value it had on entry, so nested uses behave
        correctly.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore rather than hard-code False so an exception or an
            # enclosing ``with`` does not leave the flag in the wrong state.
            cls.DEPRECATE_LEGACY = previous

154
155
156
    def __init__(
        self,
        model: str,
        *,
        task: TaskOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """Assemble `EngineArgs` from the arguments and start the engine.

        Parameters are described on the class; anything not named here is
        passed through to `EngineArgs` via ``kwargs``.
        """
        # Stats logging stays off for offline use unless explicitly requested.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class given as a type (not a qualified name string) is
        # serialized with cloudpickle to avoid pickling issues.
        requested_worker_cls = kwargs.get("worker_cls")
        if isinstance(requested_worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(requested_worker_cls)

        # Normalize compilation_config into a CompilationConfig instance.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Keys that are not init fields of CompilationConfig are dropped.
            compilation_config_instance = CompilationConfig(
                **{
                    key: value
                    for key, value in compilation_config.items()
                    if is_init_field(CompilationConfig, key)
                })
        else:
            # Anything else is passed through unchanged.
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides if hf_overrides is not None else {},
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # from_engine_args selects the concrete engine class (V0 vs V1).
        engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.llm_engine = engine
        self.engine_class = type(engine)

        self.request_counter = Counter()
        # Filled lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None
248

249
250
251
252
253
254
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the engine's tokenizer, resolved for ``lora_request``.

        The lookup is delegated to the tokenizer group's
        ``get_lora_tokenizer``; ``None`` is passed through unchanged.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
255
256

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer whose class name starts with "Cached" is stored as-is;
        any other tokenizer is wrapped with ``get_cached_tokenizer`` first.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()

        # The cached tokenizer class is created dynamically, so the class
        # name prefix is the only available signal. A user-defined class
        # whose name happens to start with 'Cached' will be misjudged.
        already_cached = tokenizer.__class__.__name__.startswith("Cached")
        tokenizer_group.tokenizer = (tokenizer if already_cached else
                                     get_cached_tokenizer(tokenizer))
266

267
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default SamplingParams derived from the model config.

        The model config's ``get_diff_sampling_param()`` result is fetched
        on first use and cached in ``self.default_sampling_params``. A falsy
        result (e.g. an empty dict) yields a plain ``SamplingParams()``.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

275
276
277
278
279
280
281
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        # The implementation accepts `priority`; without declaring it here,
        # type checkers reject `generate(..., priority=[...])`.
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Preferred form: prompts (text or token prompts) + sampling params."""
        ...

291
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy: one text prompt, optionally with its token IDs."""

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy: several text prompts, optionally with their token IDs."""

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy: one token-ID prompt, optionally with its text."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy: several token-ID prompts, optionally with their texts."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy: token IDs passed positionally after two `None`s."""

nunjunj's avatar
nunjunj committed
372
373
374
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing prompt token IDs, either
                alone or alongside text `prompts`. Prefer passing token
                prompts (e.g. `TokensPrompt`) via `prompts`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                is converted into a `GuidedDecodingRequest` and may specify
                at most one guided decoding option.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the 'generate'
                or 'transcription' runner, or if `guided_options_request`
                is a dict specifying more than one guided decoding option.

        Note:
            Passing `prompt_token_ids` (optionally together with text
            `prompts`) is considered legacy and may be deprecated in the
            future. Pass token IDs through the `prompts` parameter instead.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                # Point the user at the fix when the model could generate
                # but was initialized for another runner.
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call style: merge text prompts and token IDs into
            # prompt objects.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
472

473
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Run ``method`` on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to every worker. A callable receives
                the worker object as an extra leading ``self`` argument,
                followed by ``args`` and ``kwargs``.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, as a list.

        Note:
            Prefer using this API only for control messages; set up a
            separate data-plane channel for bulk data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
502
503

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Call ``func`` on the model held by each worker.

        Delegates to the model executor's ``apply_model`` and returns its
        per-worker results.
        """
        return self.llm_engine.model_executor.apply_model(func)
510

511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: ``None``, a single ``LoRARequest`` shared by every
                prompt, or a sequence holding one entry per prompt.
            prompts: The prompts the LoRA requests are paired with.

        Returns:
            A list with exactly one (possibly ``None``) LoRA request per
            prompt.

        Raises:
            ValueError: If a sequence is given whose length differs from the
                number of prompts.
            TypeError: If ``lora_request`` is of any other type.
        """
        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(prompts):
                raise ValueError("Lora request list should be the same "
                                 "length as the prompts")
            # A correctly sized per-prompt list is valid and returned as-is.
            # (This return used to sit after the raise, so it was unreachable
            # and valid lists fell through to the TypeError below.)
            return list(lora_request)

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

528
529
    def beam_search(
        self,
530
        prompts: list[Union[TokensPrompt, TextPrompt]],
531
        params: BeamSearchParams,
532
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
533
    ) -> list[BeamSearchOutput]:
534
535
536
537
538
539
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
540
            params: The beam search parameters.
541
            lora_request: LoRA request to use for generation, if any.
542
        """
543
544
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
545
546
547
548
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
549
550
        length_penalty = params.length_penalty

551
552
553
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

554
555
556
557
        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)
558

559
560
561
562
563
564
565
566
567
568
569
570
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
571

572
573
574
575
576
577
        tokenizer = self.get_tokenizer()
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
578
                                            temperature=temperature)
579
        instances: list[BeamSearchInstance] = []
580

581
        for lora_req, prompt in zip(lora_requests, prompts):
582
583
584
585
586
587
588
589
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

590
591
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
592
593
594
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
595

596
            instances.append(
597
598
599
600
601
602
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
603
604

        for _ in range(max_tokens):
605
            all_beams: list[BeamSearchSequence] = list(
606
607
608
609
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
610
            instance_start_and_end: list[tuple[int, int]] = list(
611
612
613
614
615
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

616
617
618
619
            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])
620
621
622
623
624

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
625
626
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
643
                                logprobs=current_beam.logprobs + [logprobs],
644
                                lora_request=current_beam.lora_request,
645
                                cum_logprob=current_beam.cum_logprob +
646
647
648
649
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)
650
651
652
653
654
655
656

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
657
                                      key=sort_beams_key,
658
659
660
661
662
663
664
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
665
                                      key=sort_beams_key,
666
667
668
669
670
671
672
673
674
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
675
676
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each chat conversation is rendered with the chat template into a
        token-ID prompt (`TokensPrompt`), and the resulting prompts are passed
        to the [generate][] method to produce the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. It is
                also used to select the tokenizer.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are forwarded to the chat
                template and also used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. Entries here override the values derived from
                `chat_template`, `add_generation_prompt`,
                `continue_final_message` and `tools`.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        # Caller-supplied template kwargs take precedence over the explicit
        # arguments above.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path produces token IDs directly from the raw
                # messages.
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

817
818
819
820
821
822
823
    # Preferred form: `prompts` is positional-only, `pooling_params` may be
    # positional or keyword, and everything else is keyword-only.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # The remaining overloads describe the legacy `prompt_token_ids` calling
    # conventions, which are flagged as deprecated for type checkers.
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
908
909
910
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Explicitly given values
                are verified against the model config before use.
            prompt_token_ids: Legacy way of passing pre-tokenized prompts;
                deprecated in favour of `prompts`.
            truncate_prompt_tokens: Optional prompt truncation length. It is
                validated against the model's `max_model_len`, and the
                resulting tokenization kwargs are passed along with the
                requests.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the 'pooling'
                runner.

        Note:
            Passing `prompt_token_ids` (or `prompts` as a keyword argument)
            is considered legacy and may be removed in the future. Pass the
            prompts positionally via `prompts` instead; pre-tokenized input
            can be given as [TokensPrompt][vllm.inputs.TokensPrompt] entries.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge text prompts and token IDs into prompt dicts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(self.llm_engine.model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(self.llm_engine.model_config)

        # Filled in by _validate_truncation_size when truncation is requested.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
1001

1002
1003
1004
1005
1006
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional prompt truncation length,
                forwarded unchanged to [encode][].
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with `--task embed`.
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional prompt truncation length,
                forwarded unchanged to [encode][]. Defaults to no truncation
                request.
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with
                `--task classify`.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        # The new keyword-only options default to None, which is exactly what
        # `encode` received before they existed, so existing callers are
        # unaffected.
        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1089
1090
1091
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs via the cosine similarity of their embeddings.

        `text_1` and `text_2` are pooled together in a single [encode][] call
        and the outputs are split back into two lists. When `text_1` yields
        exactly one output, that output is repeated once per `text_2` output
        (the `1 -> N` case). Both lists are then handed to
        `_cosine_similarity`.
        """
        num_first = len(text_1)

        # One batched encode call for both sides keeps scheduling efficient.
        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        encoded_output_1: list[PoolingRequestOutput] = encoded_output[
            :num_first]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[
            num_first:]

        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score `(text_1[i], text_2[i])` pairs with a cross-encoder model.

        Each pair is tokenized jointly (`text=` / `text_pair=`) and submitted
        as one pooling request with default `PoolingParams`. A one-element
        `text_1` is repeated to match the length of `text_2`.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`, which this
                scoring path does not support.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The failure is caused by the tokenizer type, not by the task,
            # so say so instead of pointing at `--task` options.
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        # Filled in by _validate_truncation_size when truncation is requested
        # and applied to the pair tokenization below.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1174
1175
1176
1177
1178
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 - N` case the `text_1` sentence will be replicated `N`
        times to pair with the `text_2` sentences.

        For cross-encoder models the pairs are tokenized jointly and scored
        by the model. For other (embedding) models each text is embedded
        and pairs are scored by cosine similarity. The prompts are batched
        automatically, considering the memory constraint. For the best
        performance, put all of your texts into a single list and pass it
        to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the `text_2` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See [PromptType][vllm.inputs.PromptType] for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional prompt truncation length,
                forwarded to the selected scoring path.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, the task is not
                `embed` or `score`, a prompt carries multi-modal data, or a
                prompt cannot be converted to text.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            """Convert a str / TextPrompt / TokensPrompt into plain text."""
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check rather than `assert`, which `python -O` strips.
            if not isinstance(prompt, str):
                raise ValueError(
                    "Unsupported prompt for scoring (expected str, "
                    f"TextPrompt or TokensPrompt), got "
                    f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1274

1275
1276
1277
1278
1279
1280
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to end profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1281
1282
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        `device` is passed through unchanged to the engine, and the engine's
        result is returned as-is.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1283

1284
1285
1286
1287
1288
1289
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # Both sleep levels discard the KV cache (see above), so the prefix
        # cache is reset first; presumably this keeps cached prefixes from
        # pointing at discarded KV contents after wake-up.
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][] method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                `("weights", "kv_cache")`. If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        self.llm_engine.wake_up(tags)

    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects capturing the current state of all
            aggregated metrics from Prometheus, as returned by the V1
            engine's ``get_metrics()``.

        Raises:
            AssertionError: If the engine is not the V1 ``LLMEngine``
                (the check is skipped when assertions are disabled).

        Note:
            This method is only available with the V1 LLM engine.
        """
        # Local import: the V1 engine module is only imported when this
        # method is actually called.
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        assert isinstance(self.llm_engine, V1LLMEngine), (
            "get_metrics() is only available with the V1 LLM engine.")
        return self.llm_engine.get_metrics()

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert the legacy ``prompts`` / ``prompt_token_ids`` arguments
        into a list of prompt dicts, one per request.

        Args:
            prompts: A single text prompt or a list of text prompts.
            prompt_token_ids: A single tokenized prompt (``list[int]``) or a
                list of tokenized prompts (``list[list[int]]``).

        Returns:
            A list of ``TextPrompt`` items built from ``prompts`` when it is
            given, otherwise a list of ``TokensPrompt`` items built from
            ``prompt_token_ids``. When both are given, only the text prompts
            are used to build the items.

        Raises:
            ValueError: If neither argument is given, or if both are given
                and they describe a different number of requests.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Normalize both inputs to batches *before* comparing lengths: a
        # single ``str`` or a flat ``list[int]`` is one request, so comparing
        # the raw lengths would compare characters against token IDs.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if prompts is not None and prompt_token_ids is not None \
                and len(prompts) != len(prompt_token_ids):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        # Text prompts take precedence when both inputs are supplied.
        if prompts is not None:
            return [TextPrompt(prompt=text) for text in prompts]

        # Guaranteed by the "neither provided" check above.
        assert prompt_token_ids is not None
        return [
            TokensPrompt(prompt_token_ids=token_ids)
            for token_ids in prompt_token_ids
        ]

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: bool,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and submit every prompt to the
        engine.

        Args:
            prompts: A single prompt (``str`` or prompt dict) or a sequence
                of prompts.
            params: Sampling/pooling parameters, either shared by all
                prompts or given per prompt as a sequence.
            use_tqdm: Whether to show a progress bar while adding requests.
            lora_request: A LoRA request shared by all prompts, or one per
                prompt as a sequence.
            prompt_adapter_request: Prompt adapter request passed with every
                prompt.
            tokenization_kwargs: Extra tokenization arguments forwarded to
                the engine with every prompt.
            guided_options: Deprecated. Merged into each ``SamplingParams``
                via ``_add_guided_params``.
            priority: Optional per-prompt priorities. Every request gets
                priority 0 when this is ``None`` or empty.

        Raises:
            ValueError: If a per-prompt ``params``, ``lora_request`` or
                non-empty ``priority`` sequence does not have exactly one
                entry per prompt.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Per-request arguments are detected with ``Sequence`` everywhere
        # (validation, post-processing and the per-request lookup below), so
        # a tuple is treated exactly like a list. Otherwise it would skip the
        # length check and the ``output_kind`` override.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # An empty/None ``priority`` means "priority 0 for every request"
        # (see the lookup below). A non-empty list is checked up front so a
        # short list cannot fail with IndexError after some requests were
        # already added to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            it = tqdm(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )

    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the engine under a fresh request ID.

        The request ID is the string form of the next value drawn from
        ``self.request_counter``. ``_run_engine`` sorts finished outputs by
        ``int(request_id)``, so the IDs must stay integer-valued strings.

        Args:
            prompt: The prompt to submit.
            params: Sampling or pooling parameters for this request.
            tokenization_kwargs: Extra tokenization arguments forwarded to
                the engine.
            lora_request: Optional LoRA request for this prompt.
            prompt_adapter_request: Optional prompt adapter request.
            priority: Scheduling priority forwarded to the engine.
        """
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy deprecated ``guided_options`` into ``params.guided_decoding``.

        Args:
            params: Sampling parameters to update. They are modified in
                place.
            guided_options: Legacy guided-decoding request. When ``None``,
                ``params`` is returned unchanged.

        Returns:
            The same ``params`` object that was passed in.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        # Field-by-field translation from the legacy request object.
        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern,
            structural_tag=guided_options.structural_tag,
        )
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until every pending request has finished.

        Args:
            use_tqdm: Whether to show a progress bar. For ``RequestOutput``
                results the bar also shows estimated input/output token
                throughput.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        pbar: Optional[tqdm] = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # ``elapsed`` can still be 0.0 right after the bar
                        # is created (coarse clocks), which would raise
                        # ZeroDivisionError; keep the previous postfix then.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            # Close the bar even if ``step()`` raises.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))