"tests/models/test_llava_image_embeds.py" did not exist on "8ea5e44a435e8731fd6f5ba4c329dd112752532a"
llm.py 65.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Sequence
7
from contextlib import contextmanager
8
9
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
10

11
import cloudpickle
12
import torch.nn as nn
13
from tqdm.auto import tqdm
14
from typing_extensions import TypeVar, deprecated
15

16
17
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
18
19
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
20
21
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
22
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
23
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
24
                                         ChatTemplateContentFormatOption,
25
26
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
27
28
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
29
30
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
31
from vllm.entrypoints.utils import _validate_truncation_size
32
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
33
from vllm.inputs.parse import parse_and_batch_prompt
34
from vllm.logger import init_logger
35
from vllm.lora.request import LoRARequest
36
37
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
38
from vllm.model_executor.layers.quantization import QuantizationMethods
39
40
41
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
42
from vllm.pooling_params import PoolingParams
43
from vllm.prompt_adapter.request import PromptAdapterRequest
44
45
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
46
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
47
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
48
from vllm.usage.usage_lib import UsageContext
49
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
50

51
52
53
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

54
55
logger = init_logger(__name__)

56
57
_R = TypeVar("_R", default=Any)

58
59

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
60
61
62
63
64
65
66
67
68
69
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
70
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
71
72
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
73
74
75
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
76
77
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
78
79
80
81
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
82
83
84
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
85
86
87
88
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
89
        quantization: The method used to quantize the model weights. Currently,
90
            we support "awq", "gptq", and "fp8" (experimental).
91
92
93
94
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
95
96
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
97
98
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
99
100
101
102
103
104
105
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
106
107
108
109
110
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
111
112
113
114
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
115
116
117
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
118
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
119
            When a sequence has context length larger than this, we fall back
120
121
122
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
123
124
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
125
126
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
127
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
130
131
132
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
133
134
135
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
136
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
137

138
139
    Note:
        This class is intended to be used for offline inference. For online
140
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
141
    """
142

143
    DEPRECATE_LEGACY: ClassVar[bool] = True
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily enable deprecation of the legacy generate/encode API.

        Sets ``DEPRECATE_LEGACY`` to True for the duration of the ``with``
        block. On exit, the value that was in effect before entry is
        restored, even if the block raises.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore instead of forcing False: the class default is True,
            # so unconditionally clearing the flag would silently switch the
            # deprecation off for every later caller.
            cls.DEPRECATE_LEGACY = previous

155
156
157
    def __init__(
        self,
        model: str,
        *,
        task: TaskOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """Build the engine arguments and create the underlying LLM engine.

        Extra keyword arguments are forwarded to ``EngineArgs`` after the
        adjustments below (``disable_log_stats`` defaulting to True and a
        class-valued ``worker_cls`` being cloudpickled).
        """
        # Stats logging is off unless the caller explicitly asks for it.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class passed as a type (not a qualified-name string) is
        # serialized with cloudpickle to avoid pickling issues.
        custom_worker = kwargs.get("worker_cls")
        if isinstance(custom_worker, type):
            kwargs["worker_cls"] = cloudpickle.dumps(custom_worker)

        overrides = hf_overrides if hf_overrides is not None else {}

        if compilation_config is None:
            compilation_cfg = CompilationConfig()
        elif isinstance(compilation_config, int):
            # An integer selects the compilation optimization level.
            compilation_cfg = CompilationConfig(level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Keys that are not init fields of CompilationConfig are
            # silently dropped.
            compilation_cfg = CompilationConfig(
                **{
                    key: value
                    for key, value in compilation_config.items()
                    if is_init_field(CompilationConfig, key)
                })
        else:
            # Any other object is passed through as-is (presumably an
            # already-built CompilationConfig).
            compilation_cfg = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_cfg,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None
249

250
251
252
253
254
255
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer to use, resolved for ``lora_request``.

        Delegates to ``get_lora_tokenizer`` on the engine's tokenizer group.
        """
        group = self.llm_engine.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)
256
257

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer that already appears to be cached is stored unchanged;
        anything else is first wrapped with ``get_cached_tokenizer``.
        """
        group = self.llm_engine.get_tokenizer_group()

        # CachedTokenizer classes are created dynamically, so the class-name
        # prefix is the only available check. A user-defined tokenizer class
        # whose name starts with 'Cached' will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))
267

268
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the sampling parameters used when the caller supplies none.

        On first call, the model config's ``get_diff_sampling_param()``
        result is stored in ``self.default_sampling_params``. A falsy
        (e.g. empty) diff yields a plain ``SamplingParams()``; otherwise the
        diff is passed to ``SamplingParams.from_optional``.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

276
277
278
279
280
281
282
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        # Preferred (non-legacy) signature. `priority` mirrors the parameter
        # accepted by the implementation so type checkers allow passing it.
        ...

292
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy stub: one text prompt plus optional token ids."""

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy stub: several text prompts plus optional token ids."""

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy stub: one list of token ids plus an optional text prompt."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy stub: a batch of token-id lists plus optional prompts."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy stub: token ids only, passed positionally."""

nunjunj's avatar
nunjunj committed
373
374
375
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters (see
                `get_default_sampling_params`).
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing prompt token ids
                separately from `prompts`. Deprecated; pass token ids through
                `prompts` instead (e.g. as a `TokensPrompt`).
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                may specify at most one option; it is converted into a
                `GuidedDecodingRequest`.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is neither 'generate' nor
                'transcription', or if `guided_options_request` is a dict
                with more than one entry.

        Note:
            Passing `prompt_token_ids` is legacy and is flagged as deprecated
            (via `deprecate_kwargs`) while `LLM.DEPRECATE_LEGACY` is true.
            Pass all prompt inputs, including token ids, via the `prompts`
            parameter instead.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type not in ("generate", "transcription"):
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            if "generate" in model_config.supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: combine the separate `prompts` / `prompt_token_ids`
            # arguments via `_convert_v1_inputs`.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            # Only a single guided decoding mode may be requested at a time.
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
473

474
    def collective_rpc(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> list[_R]:
        """
        Invoke a method on every worker and collect the per-worker results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to every worker for execution.

                A callable receives the worker object as an extra leading
                `self` argument, followed by `args` and `kwargs`.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, in a list.

        Note:
            Prefer using this API only for control messages; set up a
            separate data-plane channel for passing bulk data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
503
504

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` directly on the model held by each worker.

        Returns:
            The per-worker results, as returned by the model executor's
            `apply_model`.
        """
        return self.llm_engine.model_executor.apply_model(func)
511

512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: ``None``, a single request shared by every prompt,
                or a sequence holding exactly one entry per prompt.
            prompts: The prompts the requests are paired with.

        Returns:
            A list with one (possibly ``None``) LoRA request per prompt.

        Raises:
            ValueError: If a sequence of requests does not have the same
                length as ``prompts``.
            TypeError: If ``lora_request`` is of any other type.
        """
        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as "
                    "the prompts")
            # A correctly sized sequence is already per-prompt; hand back a
            # list copy so the declared return type holds.
            return list(lora_request)

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

529
530
    def beam_search(
        self,
531
        prompts: list[Union[TokensPrompt, TextPrompt]],
532
        params: BeamSearchParams,
533
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
534
    ) -> list[BeamSearchOutput]:
535
536
537
538
539
540
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
541
            params: The beam search parameters.
542
            lora_request: LoRA request to use for generation, if any.
543
        """
544
545
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
546
547
548
549
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
550
551
        length_penalty = params.length_penalty

552
553
554
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

555
556
557
558
        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)
559

560
561
562
563
564
565
566
567
568
569
570
571
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
572

573
574
575
576
577
578
        tokenizer = self.get_tokenizer()
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
579
                                            temperature=temperature)
580
        instances: list[BeamSearchInstance] = []
581

582
        for lora_req, prompt in zip(lora_requests, prompts):
583
584
585
586
587
588
589
590
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

591
592
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
593
594
595
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
596

597
            instances.append(
598
599
600
601
602
603
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
604
605

        for _ in range(max_tokens):
606
            all_beams: list[BeamSearchSequence] = list(
607
608
609
610
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
611
            instance_start_and_end: list[tuple[int, int]] = list(
612
613
614
615
616
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

617
618
619
620
            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])
621
622
623
624
625

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
626
627
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
644
                                logprobs=current_beam.logprobs + [logprobs],
645
                                lora_request=current_beam.lora_request,
646
                                cum_logprob=current_beam.cum_logprob +
647
648
649
650
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)
651
652
653
654
655
656
657

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
658
                                      key=sort_beams_key,
659
660
661
662
663
664
665
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
666
                                      key=sort_beams_key,
667
668
669
670
671
672
673
674
675
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
676
677
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Every conversation is rendered into prompt token IDs using the
        tokenizer's chat template, and the resulting prompts are handed to
        the [generate][] method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation or a list of conversations.

                - A conversation is a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. It also
                selects the tokenizer used to render the conversations.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, asks the chat template to append a
                generation prompt to the rendered conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are forwarded to the chat
                template and also used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. Keys given here override the values derived from
                the arguments above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request, attached to every generated prompt. Only used
                for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        # Normalise the input so that we always iterate over conversations.
        if is_list_of(messages, list):
            conversations = cast(list[list[ChatCompletionMessageParam]],
                                 messages)
        else:
            conversations = [cast(list[ChatCompletionMessageParam], messages)]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Caller-supplied template kwargs win over the defaults below.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        engine_prompts: list[Union[TokensPrompt, TextPrompt]] = []
        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path renders straight from the raw messages.
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already emits the special tokens, so the
                # tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            engine_prompts.append(engine_prompt)

        return self.generate(
            engine_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

818
819
820
821
822
823
824
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[
            Union[PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # Current API: prompts are positional-only; the rest is keyword-only.
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[
            Union[PoolingParams, Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # One text prompt with an optional flat list of token ids.
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[
            Union[PoolingParams, Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # A batch of text prompts with optional per-prompt token id lists.
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[
            Union[PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # Required keyword token ids for a single prompt; text is optional.
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[
            Union[PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # Required keyword token ids for a batch; texts are optional.
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # Token ids passed positionally after two explicit `None`s.
        ...

nunjunj's avatar
nunjunj committed
909
910
911
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. May be a single
                `PoolingParams` or a sequence of them; every instance passed
                in is verified against the model config.
            prompt_token_ids: LEGACY. Token IDs of the prompts. When given,
                they are combined with `prompts` and converted into the
                current prompt format.
            truncate_prompt_tokens: Optional limit on the number of prompt
                tokens. It is validated against the model's `max_model_len`
                and forwarded to tokenization.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the pooling
                runner.

        Note:
            Passing `prompts` as a keyword argument, or using
            `prompt_token_ids` at all, is considered legacy and may be
            deprecated in the future. You should instead pass the prompts
            (for example as `TokensPrompt` dicts) as the first positional
            argument.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            # Give an actionable hint when the model could run as a pooler
            # but was started with a different runner.
            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call style: merge text prompts and token ids.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(model_config)

        # Populated by the validator with the truncation settings to apply.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
1002

1003
1004
1005
1006
1007
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per input prompt.

        This is a thin wrapper over `encode` that wraps every pooled result
        as an `EmbeddingRequestOutput`. Prompts are batched automatically
        with the memory constraint in mind, so for the best performance pass
        all of your prompts in a single list.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Optional limit on the number of prompt
                tokens, forwarded to `encode`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not started with `--task embed`.
        """
        task = self.llm_engine.model_config.task
        if task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Optional limit on the number of prompt
                tokens, forwarded to `encode` (same as in `embed`). Defaults
                to None, i.e. no truncation is requested.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model was not started with `--task classify`.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1090
1091
1092
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by cosine similarity of their pooled embeddings.

        Both sides are encoded together in a single `encode` call and the
        results are split back apart. When `text_1` holds a single entry,
        its embedding is reused for every entry of `text_2`.
        """
        n_left = len(text_1)
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        left: list[PoolingRequestOutput] = embeddings[:n_left]
        right: list[PoolingRequestOutput] = embeddings[n_left:]

        # 1 -> N: broadcast the single query embedding across all documents.
        if len(left) == 1:
            left = left * len(right)

        similarities = _cosine_similarity(tokenizer=tokenizer,
                                          embed_1=left,
                                          embed_2=right)

        validated = self.engine_class.validate_outputs(similarities,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(entry) for entry in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs with a cross-encoder model.

        Each `(text_1[i], text_2[i])` pair is tokenized jointly through the
        tokenizer's `text_pair` argument. `token_type_ids` are forwarded
        when the tokenizer produces them. The resulting prompts are run
        through the pooling model. A single-entry `text_1` is replicated to
        match the length of `text_2`.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`, which this
                pair-tokenization path does not support.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The error is about the tokenizer, not the task, so say so.
            raise ValueError("Score API with a cross-encoder is not "
                             "supported for MistralTokenizer")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        # Populated by the validator with the truncation settings to apply.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []
        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1175
1176
1177
1178
1179
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `text_1` sentence will be replicated `N`
        times to pair with the `text_2` sentences.

        For cross-encoder models each pair is tokenized jointly and scored
        by the model. For embedding models both sides are embedded and the
        score is the cosine similarity of the two embeddings. This method
        automatically batches the prompts, considering the memory
        constraint. For the best performance, put all of your texts into a
        single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the `text_2` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See [PromptType][vllm.inputs.PromptType] for
                more details about the format of each prompt.
            truncate_prompt_tokens: Optional limit on the number of prompt
                tokens, forwarded to the underlying scoring path.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, its task is not
                `embed` or `score`, or a prompt contains multi-modal data.
            TypeError: If a prompt cannot be converted to a plain string.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            """Reduce a text/token prompt to its plain-string form."""
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # An explicit check instead of `assert`, which is stripped under
            # `python -O` and would let malformed prompts reach the tokenizer.
            if not isinstance(prompt, str):
                raise TypeError(
                    "Unsupported prompt for scoring: expected a string, "
                    "TextPrompt or TokensPrompt, got "
                    f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1275

1276
1277
1278
1279
1280
1281
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Forward a stop-profile request to ``self.llm_engine``."""
        engine = self.llm_engine
        engine.stop_profile()

1282
1283
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Passed through unchanged to
                ``self.llm_engine.reset_prefix_cache``; ``None`` by default.

        Returns:
            The boolean reported by the engine's ``reset_prefix_cache``.
        """
        succeeded = self.llm_engine.reset_prefix_cache(device)
        return succeeded
1284

1285
1286
1287
1288
1289
1290
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset before the engine is put to sleep.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # Both sleep levels discard the kv cache (see above); presumably the
        # prefix cache is cleared first so it does not outlive that content.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1307
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][] method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                `("weights", "kv_cache")`. If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1320

1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects capturing the current state of all
            aggregated metrics from Prometheus.

        Raises:
            RuntimeError: If ``self.llm_engine`` is not the V1 ``LLMEngine``.

        Note:
            This method is only available with the V1 LLM engine.
        """
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

        # Use an explicit check instead of ``assert`` so the guard is not
        # stripped under ``python -O`` and the caller gets an actionable
        # message rather than a bare AssertionError.
        if not isinstance(self.llm_engine, V1LLMEngine):
            raise RuntimeError(
                "get_metrics() is only supported by the V1 LLM engine; "
                f"got {type(self.llm_engine).__name__}.")
        return self.llm_engine.get_metrics()

1335
1336
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert legacy ``prompts`` / ``prompt_token_ids`` arguments into a
        list of prompt dicts, one per request.

        Args:
            prompts: A single text prompt, a list of them, or ``None``.
            prompt_token_ids: A single token-ID prompt (``list[int]``), a list
                of them (``list[list[int]]``), or ``None``.

        Returns:
            One ``TextPrompt`` per request when ``prompts`` is given (in that
            case ``prompt_token_ids`` is only used for the length check);
            otherwise one ``TokensPrompt`` per request.

        Raises:
            ValueError: If neither argument is provided, or if both are
                provided but describe a different number of requests.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Normalize into batches (a single prompt becomes a one-element
        # batch) *before* comparing lengths. Comparing the raw arguments is
        # wrong for single prompts: ``len("hello")`` counts characters while
        # ``len([1, 2, 3])`` counts tokens, so a valid single/single pair
        # could be rejected, and ``["a", "b"]`` vs ``[1, 2]`` (two prompts vs
        # one token prompt) would be accepted.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if prompts is not None and prompt_token_ids is not None \
                and len(prompts) != len(prompt_token_ids):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        parsed_prompts: list[PromptType]
        if prompts is not None:
            # Text prompts take precedence when both forms are supplied.
            parsed_prompts = [TextPrompt(prompt=text) for text in prompts]
        else:
            # Guaranteed by the "Either ... must be provided" check above.
            assert prompt_token_ids is not None
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=ids) for ids in prompt_token_ids
            ]
        return parsed_prompts
1376
1377
1378

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: bool,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and submit every prompt.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: One params object shared by all prompts, or a sequence
                with one entry per prompt.
            use_tqdm: Whether to wrap submission in a tqdm progress bar.
            lora_request: One LoRA request shared by all prompts, a sequence
                with one entry per prompt, or ``None``.
            prompt_adapter_request: Passed unchanged to every request.
            tokenization_kwargs: Passed unchanged to every request.
            guided_options: Deprecated; merged into each ``SamplingParams``
                via ``_add_guided_params``.
            priority: Optional per-prompt priorities. When ``None`` or empty,
                every request gets priority 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or ``priority`` is a
                per-prompt sequence whose length differs from ``prompts``.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Validate with ``Sequence`` to match how the submission loop below
        # indexes ``params`` / ``lora_request``. Checking only ``list`` let a
        # tuple of the wrong length through, and made a tuple of
        # SamplingParams skip the per-params setup loop.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # ``priority[i]`` is read for every prompt; without this check a
        # short list fails with IndexError after some requests were already
        # submitted to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            it = tqdm(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1433

1434
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the engine under a fresh request ID.

        The ID is the next value drawn from ``self.request_counter``,
        converted to ``str``. All other arguments are forwarded unchanged to
        ``self.llm_engine.add_request``.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            prompt,
            params,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1453

1454
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Fold deprecated ``guided_options`` into ``params.guided_decoding``.

        ``params`` is modified in place and also returned.

        Args:
            params: Sampling parameters to update.
            guided_options: Legacy guided-decoding request. When ``None``,
                ``params`` is returned untouched.

        Returns:
            The same ``params`` object that was passed in.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")
            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )
        return params

1477
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to show a progress bar. For ``RequestOutput``s
                the bar's postfix also shows estimated input/output token
                throughput.

        Returns:
            All finished outputs, sorted by their integer request ID.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # Read the elapsed time once so both speeds share a
                        # denominator, and skip the postfix refresh if it is
                        # still 0 (coarse timers) instead of dividing by 0.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            # Close the bar even if the engine raises mid-run.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))