llm.py 66.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
8
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
9

10
import cloudpickle
11
import torch.nn as nn
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar, deprecated
14

15
16
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
17
18
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
19
20
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
21
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
22
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
23
                                         ChatTemplateContentFormatOption,
24
25
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
26
27
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
28
29
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
30
from vllm.entrypoints.utils import _validate_truncation_size
31
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
32
from vllm.inputs.parse import parse_and_batch_prompt
33
from vllm.logger import init_logger
34
from vllm.lora.request import LoRARequest
35
36
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
37
from vllm.model_executor.layers.quantization import QuantizationMethods
38
39
40
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
41
from vllm.pooling_params import PoolingParams
42
from vllm.prompt_adapter.request import PromptAdapterRequest
43
44
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
45
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
46
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
47
from vllm.usage.usage_lib import UsageContext
48
49
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
50

51
52
53
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

54
55
# Module-level logger, created through vLLM's `init_logger` helper and keyed
# by this module's `__name__`.
logger = init_logger(__name__)

56
57
# Generic result type used by `LLM.collective_rpc` and `LLM.apply_model` to
# tie the return type of a user-supplied callable to the element type of the
# returned list. Falls back to `Any` when it cannot be inferred.
_R = TypeVar("_R", default=Any)

58
59

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
60
61
62
63
64
65
66
67
68
69
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
70
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
71
72
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
73
74
75
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
76
77
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
78
79
80
81
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
82
83
84
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
85
86
87
88
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
89
        quantization: The method used to quantize the model weights. Currently,
90
            we support "awq", "gptq", and "fp8" (experimental).
91
92
93
94
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
95
96
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
97
98
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
99
100
101
102
103
104
105
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
106
107
108
109
110
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Noting that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
111
112
113
114
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
115
116
117
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
118
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
119
            When a sequence has context length larger than this, we fall back
120
121
122
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
123
124
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
125
126
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
127
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
130
131
132
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
133
134
135
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
136
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
137

138
139
    Note:
        This class is intended to be used for offline inference. For online
140
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
141
    """
142

143
    DEPRECATE_LEGACY: ClassVar[bool] = True
144
145
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

146
147
148
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
149
    [LLM.__init__][].
150
151
    """

152
153
154
155
156
157
158
159
160
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

161
162
163
164
165
166
167
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """Assemble `EngineArgs` from the arguments and create the engine.

        The individual parameters are described in the class docstring.
        Extra `**kwargs` are forwarded to `EngineArgs` after two
        adjustments: `disable_log_stats` defaults to `True` when the caller
        did not supply it, and a `worker_cls` given as a class object is
        replaced by its cloudpickle serialization.
        """
        # Only fill in the default when the caller did not pass the key.
        kwargs.setdefault("disable_log_stats", True)

        # If the worker_cls is not a qualified string name, serialize it
        # with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if hf_overrides is None:
            hf_overrides = {}

        # Normalize `compilation_config` into a config object: `None` gives
        # the default config, an int is used as the level, a dict is
        # filtered down to CompilationConfig init fields, and any other
        # value is passed through untouched.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            init_kwargs = {
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            }
            compilation_config_instance = CompilationConfig(**init_kwargs)
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled in on first use by `get_default_sampling_params`.
        self.default_sampling_params: Union[dict[str, Any], None] = None
262

263
264
265
266
267
268
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
            lora_request)
269
270

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
271
        tokenizer_group = self.llm_engine.get_tokenizer_group()
272

273
274
275
276
        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        if tokenizer.__class__.__name__.startswith("Cached"):
277
            tokenizer_group.tokenizer = tokenizer
278
        else:
279
            tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
280

281
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default sampling parameters for this engine's model.

        The overrides come from `model_config.get_diff_sampling_param()` and
        are stored on `self.default_sampling_params` the first time they are
        needed. A non-empty result is turned into parameters through
        `SamplingParams.from_optional`; an empty one yields a plain
        `SamplingParams()`.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides

        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

289
290
291
292
293
294
295
    @overload  # CURRENT: positional-only prompts, keyword-only options
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[
            SamplingParams, Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[
            list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[
            LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

305
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[
            SamplingParams, list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[
            list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[
            LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[
            SamplingParams, list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[
            list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[
            LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[
            SamplingParams, list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[
            list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[
            LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[
            SamplingParams, list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[
            list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[
            LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[
            list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[
            LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
386
387
388
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, the result of `get_default_sampling_params()` is used.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing token ids. When given,
                `prompts` and `prompt_token_ids` are combined through
                `_convert_v1_inputs`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                is converted to a `GuidedDecodingRequest` and may hold at
                most one entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is neither "generate" nor
                "transcription", or if `guided_options_request` is a dict
                with more than one entry.

        Note:
            Passing `prompt_token_ids` is considered legacy and may be
            deprecated in the future. You should instead pass the token ids
            via the `prompts` parameter.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type not in ("generate", "transcription"):
            error_parts = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]
            # Point the user at the fix when the model could have been
            # started with the 'generate' runner.
            if "generate" in model_config.supported_runner_types:
                error_parts.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")
            raise ValueError(" ".join(error_parts))

        if prompt_token_ids is None:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)
        else:
            # Legacy call style: merge text prompts and token ids.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority,
        )

        engine_outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(engine_outputs,
                                                  RequestOutput)
486

487
    def collective_rpc(self,
488
                       method: Union[str, Callable[..., _R]],
489
                       timeout: Optional[float] = None,
490
491
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
492
493
494
495
496
497
498
499
500
501
502
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
503
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
504
505
506
507
508
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.
509

510
511
512
        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
513
        """
514
515

        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
516
517

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
518
        """
519
520
        Run a function directly on the model inside each worker,
        returning the result for each of them.
521
        """
522
523
        executor = self.llm_engine.model_executor
        return executor.apply_model(func)
524

525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt."""
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")
            return lora_request

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

542
543
    def beam_search(
        self,
544
        prompts: list[Union[TokensPrompt, TextPrompt]],
545
        params: BeamSearchParams,
546
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
547
    ) -> list[BeamSearchOutput]:
548
549
550
551
552
553
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
554
            params: The beam search parameters.
555
            lora_request: LoRA request to use for generation, if any.
556
        """
557
558
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
559
560
561
562
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
563
564
        length_penalty = params.length_penalty

565
566
567
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

568
569
570
571
        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)
572

573
574
575
576
577
578
579
580
581
582
583
584
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
585

586
587
588
589
590
591
        tokenizer = self.get_tokenizer()
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
592
                                            temperature=temperature)
593
        instances: list[BeamSearchInstance] = []
594

595
        for lora_req, prompt in zip(lora_requests, prompts):
596
597
598
599
600
601
602
603
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

604
605
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
606
607
608
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
609

610
            instances.append(
611
612
613
614
615
616
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
617
618

        for _ in range(max_tokens):
619
            all_beams: list[BeamSearchSequence] = list(
620
621
622
623
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
624
            instance_start_and_end: list[tuple[int, int]] = list(
625
626
627
628
629
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

630
631
632
633
            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])
634
635
636
637
638

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
639
640
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
657
                                logprobs=current_beam.logprobs + [logprobs],
658
                                lora_request=current_beam.lora_request,
659
                                cum_logprob=current_beam.cum_logprob +
660
661
662
663
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)
664
665
666
667
668
669
670

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
671
                                      key=sort_beams_key,
672
673
674
675
676
677
678
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
679
                                      key=sort_beams_key,
680
681
682
683
684
685
686
687
688
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
689
690
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the [generate][] method to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional list of tool definitions. They are forwarded to
                the chat template and are also used when resolving the
                content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. Entries given here override the explicit template
                arguments above when the keys collide.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Explicit arguments first; caller-supplied kwargs are applied last
        # so that they take precedence on key collisions.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

831
832
833
834
835
836
837
    # Preferred form: prompts are positional-only, everything after
    # `pooling_params` is keyword-only.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
922
923
924
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            prompt_token_ids: Legacy way of passing token ids; when given,
                it is combined with `prompts` into the new prompt format.
            truncate_prompt_tokens: Optional truncation size. It is validated
                against the model's maximum length and the resulting
                tokenization kwargs are forwarded with the requests.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the "pooling"
                runner.

        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `prompts` parameter.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call style: merge text prompts and token ids.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(self.llm_engine.model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(self.llm_engine.model_config)

        # Filled in by `_validate_truncation_size` below.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
1015

1016
1017
1018
1019
1020
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional truncation size, passed through
                unchanged to `encode`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model's task is not "embed".
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model's task is not "classify".
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1103
1104
1105
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score prompt pairs by comparing their pooled embeddings.

        Both lists are encoded in a single `encode` call, split back into
        their two halves, and handed to `_cosine_similarity`.

        Args:
            tokenizer: Tokenizer forwarded to `_cosine_similarity`.
            text_1: First side of each pair. When it holds exactly one
                entry, its embedding is repeated for every `text_2` entry.
            text_2: Second side of each pair.
            truncate_prompt_tokens: Optional truncation size passed to
                `encode`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request forwarded to `encode`, if any.
            prompt_adapter_request: Prompt Adapter request forwarded to
                `encode`, if any.

        Returns:
            A list of `ScoringRequestOutput` objects built from the
            similarity results.
        """
        # Encode both sides together so they share a single batch.
        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        encoded_output_1: list[PoolingRequestOutput] = encoded_output[
            0:len(text_1)]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[
            len(text_1):]

        # 1 -> N: reuse the single left-hand embedding for every pair.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by feeding each pair to the model as one prompt.

        Each `(text_1[i], text_2[i])` pair is tokenized together via the
        tokenizer's `text` / `text_pair` arguments and submitted as a
        pooling request.

        Args:
            tokenizer: Tokenizer used to encode each pair. A
                `MistralTokenizer` is rejected.
            text_1: First side of each pair. When it holds exactly one
                entry, that entry is repeated for every `text_2` entry.
            text_2: Second side of each pair.
            truncate_prompt_tokens: Optional truncation size; validated
                against the model's maximum length and applied through the
                tokenization kwargs.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to attach to the requests, if any.
            prompt_adapter_request: Prompt Adapter request to attach to the
                requests, if any.

        Returns:
            A list of `ScoringRequestOutput` objects, one per pair.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The previous message talked about `--task`, which is not what
            # this guard checks.
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # 1 -> N: reuse the single left-hand text for every pair.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        # Filled in by `_validate_truncation_size` below.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1188
1189
1190
1191
1192
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `text_1` sentence will be replicated `N`
        times to pair with the `text_2` sentences.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the `text_2` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See [PromptType][vllm.inputs.PromptType] for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional truncation size, forwarded to
                the scoring implementation that is selected.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the "pooling"
                runner, if its task is neither "embed" nor "score", or if
                a prompt carries multi-modal data.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt):
            # Reduce any supported prompt form to plain text; token-id
            # prompts are decoded back to a string.
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            assert type(prompt) is str
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1288

1289
1290
1291
1292
1293
1294
    def start_profile(self) -> None:
        """Start profiling by delegating to the underlying engine.

        This is a thin wrapper around ``self.llm_engine.start_profile()``.
        """
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to the underlying engine.

        This is a thin wrapper around ``self.llm_engine.stop_profile()``.
        """
        engine = self.llm_engine
        engine.stop_profile()

1295
1296
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, forwarded unchanged to
                ``self.llm_engine.reset_prefix_cache``.

        Returns:
            The value returned by the engine's ``reset_prefix_cache`` call.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1297

1298
1299
1300
1301
1302
1303
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. While asleep, the engine should not
        process any requests; the caller must guarantee that no requests
        are being processed between this call and the matching `wake_up`.

        The prefix cache is reset before the engine is put to sleep.

        Args:
            level: The sleep level.

                Level 1 offloads the model weights and discards the kv
                cache (its content is forgotten). The weights are backed
                up in CPU memory, so make sure there is enough CPU memory
                to hold them. Use it to sleep and later wake up the engine
                to run the same model again.

                Level 2 discards both the model weights and the kv cache
                (the content of both is forgotten). Use it to sleep and
                later wake up the engine to run a different model or an
                updated model, where the previous weights are not needed.
                It reduces CPU memory pressure.
        """
        # Order matters: drop the cached prefixes first, then sleep.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1320
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][] method
        for more details.

        Args:
            tags: An optional list of tags selecting which memory
                allocations to reallocate. Values must be in
                `("weights", "kv_cache")`. If None, all memory is
                reallocated. wake_up should be called with all tags
                (or None) before the engine is used again.
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1333

1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            The list of ``Metric`` objects produced by the engine's
            ``get_metrics()`` call, capturing the current state of the
            aggregated metrics.

        Note:
            This method is only available with the V1 LLM engine; an
            assertion fails if ``self.llm_engine`` is any other type.
        """
        # Imported lazily, inside the method, exactly as the check needs it.
        from vllm.v1.engine.llm_engine import LLMEngine as _V1Engine

        engine = self.llm_engine
        assert isinstance(engine, _V1Engine)
        return engine.get_metrics()

1348
1349
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert legacy ``prompts`` / ``prompt_token_ids`` arguments into
        a list of prompt objects.

        Each argument may be a single prompt or a batch of prompts; both
        are normalized to batches with ``parse_and_batch_prompt`` before
        anything else is done with them.

        Args:
            prompts: A single text prompt, a list of text prompts, or None.
            prompt_token_ids: A single list of token IDs, a list of such
                lists, or None.

        Returns:
            One ``TextPrompt`` per entry of ``prompts`` when ``prompts`` is
            given; otherwise one ``TokensPrompt`` per entry of
            ``prompt_token_ids``. When both are given, only ``prompts`` is
            used to build the result.

        Raises:
            ValueError: If both arguments are None, or if both are given
                but describe a different number of prompts.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Normalize to batches *before* comparing lengths. Comparing the raw
        # inputs would measure the number of characters of a single `str`
        # prompt (or the number of tokens of a single token list) instead of
        # the number of prompts.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if prompts is not None and prompt_token_ids is not None \
                and len(prompts) != len(prompt_token_ids):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        parsed_prompts: list[PromptType]
        if prompts is not None:
            parsed_prompts = [TextPrompt(prompt=text) for text in prompts]
        else:
            # Guaranteed by the first check; kept only for type narrowing.
            assert prompt_token_ids is not None
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=token_ids)
                for token_ids in prompt_token_ids
            ]

        return parsed_prompts
1389
1390
1391

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: bool,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate the per-request arguments and add every prompt to the
        engine.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: One params object shared by all prompts, or a sequence
                with one params object per prompt.
            use_tqdm: Whether to wrap the prompt iteration in a tqdm
                progress bar.
            lora_request: One LoRA request shared by all prompts, a
                sequence with one LoRA request per prompt, or None.
            prompt_adapter_request: Passed unchanged to every request.
            tokenization_kwargs: Passed unchanged to every request.
            guided_options: Deprecated; when given, a
                ``DeprecationWarning`` is emitted and the options are
                applied to every ``SamplingParams`` object.
            priority: Optional per-prompt priorities; when empty or None,
                every request gets priority 0.

        Raises:
            ValueError: If ``params`` or ``lora_request`` is a sequence
                whose length differs from the number of prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # NOTE: the checks below test for `Sequence`, matching the test used
        # for per-request indexing further down. Testing for `list` here
        # would let a tuple bypass length validation and, for `params`,
        # would treat the whole tuple as one params object.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            it = tqdm(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1446

1447
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine under a fresh request ID.

        The request ID is the string form of the next value drawn from
        ``self.request_counter``. All other arguments are forwarded
        unchanged to ``self.llm_engine.add_request``.

        Args:
            prompt: The prompt to submit.
            params: The sampling or pooling params for this request.
            tokenization_kwargs: Optional tokenization keyword arguments.
            lora_request: Optional LoRA request for this prompt.
            prompt_adapter_request: Optional prompt adapter request.
            priority: Priority value for this request.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1466

1467
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy legacy guided-decoding options onto ``params``.

        Args:
            params: The sampling params to update in place.
            guided_options: The legacy guided-decoding request, or None to
                leave ``params`` untouched.

        Returns:
            The same ``params`` object that was passed in.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )

        return params

1490
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to display a tqdm progress bar with estimated
                input/output token throughput.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        engine = self.llm_engine

        # Set up the progress bar, sized by the number of pending requests.
        if use_tqdm:
            progress = tqdm(
                total=engine.get_num_unfinished_requests(),
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Drive the engine, collecting every output once it has finished.
        finished: list[Union[RequestOutput, PoolingRequestOutput]] = []
        prompt_tokens_seen = 0
        generated_tokens_seen = 0
        while engine.has_unfinished_requests():
            for output in engine.step():
                if not output.finished:
                    continue
                finished.append(output)

                if not use_tqdm:
                    continue
                if not isinstance(output, RequestOutput):
                    # No token statistics for other output types.
                    progress.update(1)
                    continue

                # Token statistics are only computed for RequestOutput.
                n = len(output.outputs)
                assert output.prompt_token_ids is not None
                prompt_tokens_seen += len(output.prompt_token_ids) * n
                in_spd = prompt_tokens_seen / progress.format_dict["elapsed"]
                generated_tokens_seen += sum(
                    len(completion.token_ids)
                    for completion in output.outputs)
                out_spd = (generated_tokens_seen /
                           progress.format_dict["elapsed"])
                progress.postfix = (
                    f"est. speed input: {in_spd:.2f} toks/s, "
                    f"output: {out_spd:.2f} toks/s")
                progress.update(n)

        if use_tqdm:
            progress.close()

        # Requests may finish out of submission order, so restore the order
        # by sorting on the numeric request ID.
        finished.sort(key=lambda out: int(out.request_id))
        return finished