llm.py 69.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Sequence
7
from contextlib import contextmanager
8
9
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
10

11
import cloudpickle
12
import torch.nn as nn
13
from pydantic import ValidationError
14
from tqdm.auto import tqdm
15
from typing_extensions import TypeVar, deprecated
16

17
18
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
19
20
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
21
22
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
23
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
24
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
25
                                         ChatTemplateContentFormatOption,
26
27
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
28
29
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
30
31
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
32
from vllm.entrypoints.utils import _validate_truncation_size
33
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
34
from vllm.inputs.parse import parse_and_batch_prompt
35
from vllm.logger import init_logger
36
from vllm.lora.request import LoRARequest
37
38
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
39
from vllm.model_executor.layers.quantization import QuantizationMethods
40
41
42
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
43
from vllm.pooling_params import PoolingParams
44
from vllm.prompt_adapter.request import PromptAdapterRequest
45
46
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
47
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
48
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
49
from vllm.usage.usage_lib import UsageContext
50
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
51

52
53
54
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

55
56
logger = init_logger(__name__)

57
58
_R = TypeVar("_R", default=Any)

59
60

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
61
62
63
64
65
66
67
68
69
70
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
71
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
72
73
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
74
75
76
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
77
78
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
79
80
81
82
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
83
84
85
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
86
87
88
89
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
90
        quantization: The method used to quantize the model weights. Currently,
91
            we support "awq", "gptq", and "fp8" (experimental).
92
93
94
95
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
96
97
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
98
99
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
100
101
102
103
104
105
106
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
107
108
109
110
111
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
112
113
114
115
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
116
117
118
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
119
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
120
            When a sequence has context length larger than this, we fall back
121
122
123
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
124
125
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
126
127
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
128
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
131
132
133
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
134
135
136
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
137
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
138

139
140
    Note:
        This class is intended to be used for offline inference. For online
141
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
142
    """
143

144
    DEPRECATE_LEGACY: ClassVar[bool] = True
145
146
147
148
149
150
151
152
153
154
155
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

156
157
158
    def __init__(
        self,
        model: str,
        *,
        task: TaskOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Normalizes a few loosely-typed arguments, packs everything into an
        ``EngineArgs`` and creates the engine through
        ``LLMEngine.from_engine_args``.

        Raises:
            ValueError: If ``kv_transfer_config`` is given as a dict that
                cannot be converted into a ``KVTransferConfig``.
        """
        # Stats logging is disabled unless the caller set it explicitly.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class passed as a type (not a qualified string name) is
        # serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # Accept kv_transfer_config as a plain dict and convert it here.
        raw_kv_config = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_config, dict):
            from vllm.config import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_kv_config)
            except ValidationError as err:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_config, err)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {err}") from err

        hf_overrides = {} if hf_overrides is None else hf_overrides

        # compilation_config may be an optimization level (int), a dict of
        # CompilationConfig fields (keys that are not init fields are
        # dropped), or a ready-made CompilationConfig.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            init_fields = {
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            }
            compilation_config_instance = CompilationConfig(**init_fields)
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None
268

269
270
271
272
273
274
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
            lora_request)
275
276

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
277
        tokenizer_group = self.llm_engine.get_tokenizer_group()
278

279
280
281
282
        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        if tokenizer.__class__.__name__.startswith("Cached"):
283
            tokenizer_group.tokenizer = tokenizer
284
        else:
285
            tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
286

287
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default ``SamplingParams`` for this model.

        The result of ``model_config.get_diff_sampling_param()`` is fetched on
        the first call and memoized in ``self.default_sampling_params``. A
        non-empty result is expanded into ``SamplingParams.from_optional``.
        Otherwise a plain ``SamplingParams()`` is returned.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

295
296
297
298
299
300
301
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        # Mirrors the implementation's `priority` parameter so type checkers
        # accept `generate(..., priority=[...])` on the non-legacy overload.
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
392
393
394
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
395
        additional_message="Please use the 'prompts' parameter instead.",
nunjunj's avatar
nunjunj committed
396
    )
397
398
    def generate(
        self,
399
        prompts: Union[Union[PromptType, Sequence[PromptType]],
400
                       Optional[Union[str, list[str]]]] = None,
401
402
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
403
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
404
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
405
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
406
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
407
        guided_options_request: Optional[Union[LLMGuidedOptions,
408
                                               GuidedDecodingRequest]] = None,
409
410
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
Woosuk Kwon's avatar
Woosuk Kwon committed
411
412
        """Generates the completions for the input prompts.

413
        This class automatically batches the given prompts, considering
Woosuk Kwon's avatar
Woosuk Kwon committed
414
415
416
417
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
418
            prompts: The prompts to the LLM. You may pass a sequence of prompts
419
                for batch inference. See [PromptType][vllm.inputs.PromptType]
420
                for more details about the format of each prompts.
Woosuk Kwon's avatar
Woosuk Kwon committed
421
            sampling_params: The sampling parameters for text generation. If
nunjunj's avatar
nunjunj committed
422
423
424
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
425
                prompts and it is paired one by one with the prompt.
426
427
428
429
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
430
            lora_request: LoRA request to use for generation, if any.
nunjunj's avatar
nunjunj committed
431
            prompt_adapter_request: Prompt Adapter request to use for
432
                generation, if any.
433
434
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
Woosuk Kwon's avatar
Woosuk Kwon committed
435
436

        Returns:
437
            A list of `RequestOutput` objects containing the
438
            generated completions in the same order as the input prompts.
439

440
441
442
443
        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `inputs` parameter.
444
        """
445
        runner_type = self.llm_engine.model_config.runner_type
446
        if runner_type not in ["generate", "transcription"]:
447
            messages = [
448
                "LLM.generate() is only supported for (conditional) generation "
449
450
451
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

452
453
454
            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
455
                messages.append(
456
457
458
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")
459
460

            raise ValueError(" ".join(messages))
461

462
        if prompt_token_ids is not None:
463
            parsed_prompts = self._convert_v1_inputs(
464
                prompts=cast(Optional[Union[str, list[str]]], prompts),
465
466
467
                prompt_token_ids=prompt_token_ids,
            )
        else:
468
469
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)
470

471
472
473
474
475
476
477
478
        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

479
480
        if sampling_params is None:
            # Use default sampling params.
481
            sampling_params = self.get_default_sampling_params()
482

483
        self._validate_and_add_requests(
484
            prompts=parsed_prompts,
485
            params=sampling_params,
486
            use_tqdm=use_tqdm,
487
            lora_request=lora_request,
488
            prompt_adapter_request=prompt_adapter_request,
489
            guided_options=guided_options_request,
490
491
            priority=priority,
        )
492

493
        outputs = self._run_engine(use_tqdm=use_tqdm)
Joe Runde's avatar
Joe Runde committed
494
        return self.engine_class.validate_outputs(outputs, RequestOutput)
495

496
    def collective_rpc(self,
497
                       method: Union[str, Callable[..., _R]],
498
                       timeout: Optional[float] = None,
499
500
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
501
502
503
504
505
506
507
508
509
510
511
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
512
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
513
514
515
516
517
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.
518

519
520
521
        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
522
        """
523
524

        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
525
526

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
527
        """
528
529
        Run a function directly on the model inside each worker,
        returning the result for each of them.
530
        """
531
532
        executor = self.llm_engine.model_executor
        return executor.apply_model(func)
533

534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt."""
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

550
551
    def beam_search(
        self,
552
        prompts: list[Union[TokensPrompt, TextPrompt]],
553
        params: BeamSearchParams,
554
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
555
        use_tqdm: bool = False,
556
    ) -> list[BeamSearchOutput]:
557
558
559
560
561
562
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
563
            params: The beam search parameters.
564
            lora_request: LoRA request to use for generation, if any.
565
            use_tqdm: Whether to use tqdm to display the progress bar.
566
        """
567
568
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
569
570
571
572
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
573
574
        length_penalty = params.length_penalty

575
576
577
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

578
579
580
581
        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)
582

583
584
585
586
587
588
589
590
591
592
593
594
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
595

596
597
598
599
600
601
        tokenizer = self.get_tokenizer()
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
602
                                            temperature=temperature)
603
        instances: list[BeamSearchInstance] = []
604

605
        for lora_req, prompt in zip(lora_requests, prompts):
606
607
608
609
610
611
612
613
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

614
615
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
616
617
618
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
619

620
            instances.append(
621
622
623
624
625
626
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
627

628
629
630
631
632
633
634
635
636
637
638
639
        token_iter = range(max_tokens)
        if use_tqdm:
            token_iter = tqdm(token_iter,
                              desc="Beam search",
                              unit="token",
                              unit_scale=False)
            logger.warning(
                "The progress bar shows the upper bound on token steps and "
                "may finish early due to stopping conditions. It does not "
                "reflect instance-level progress.")

        for _ in token_iter:
640
            all_beams: list[BeamSearchSequence] = list(
641
642
643
644
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
645
            instance_start_and_end: list[tuple[int, int]] = list(
646
647
648
649
650
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

651
652
653
654
            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])
655
656
657
658
659

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
660
661
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
678
                                logprobs=current_beam.logprobs + [logprobs],
679
                                lora_request=current_beam.lora_request,
680
                                cum_logprob=current_beam.cum_logprob +
681
682
683
684
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)
685
686
687
688
689
690
691

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
692
                                      key=sort_beams_key,
693
694
695
696
697
698
699
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
700
                                      key=sort_beams_key,
701
702
703
704
705
706
707
708
709
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
710
711
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Every conversation is rendered with a chat template (the Mistral
        template when the tokenizer is a `MistralTokenizer`, otherwise the
        HF template), tokenized, and the resulting token prompts are handed
        to the [generate][] method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: Either a single conversation or a list of conversations.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. It is
                also used to select the tokenizer.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are forwarded to the chat
                template and also used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. They override the template kwargs derived from the
                arguments above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        # Distinguish a batch of conversations (list of lists) from a single
        # conversation (list of messages).
        if is_list_of(messages, list):
            conversations = cast(list[list[ChatCompletionMessageParam]],
                                 messages)
        else:
            conversations = [cast(list[ChatCompletionMessageParam], messages)]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Caller-supplied chat_template_kwargs are applied last, so they win
        # over the values derived from the explicit arguments.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
        }
        template_kwargs.update(chat_template_kwargs or {})

        engine_prompts: list[Union[TokensPrompt, TextPrompt]] = []
        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            parsed_conversation, mm_data = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=parsed_conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already contains the special tokens, so
                # the tokenizer must not insert them a second time.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            engine_prompts.append(engine_prompt)

        return self.generate(
            engine_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

855
856
857
858
859
860
861
    # Preferred form: prompts are positional-only; everything after the
    # optional pooling params is keyword-only.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[
            PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Pool one prompt or a sequence of prompts."""

870
    # LEGACY: a single text prompt, optionally with its token IDs.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[
            PoolingParams, Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy: one text prompt plus optional token IDs."""
884

885
    # LEGACY: several text prompts, optionally with their token IDs.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[
            PoolingParams, Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy: a list of text prompts plus optional token IDs."""

    # LEGACY: token IDs for one prompt (keyword-only), optionally with text.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[
            PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy: token IDs of one prompt plus an optional text prompt."""

    # LEGACY: token IDs for several prompts (keyword-only), optionally with
    # their texts.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[
            PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy: token IDs of several prompts plus optional texts."""

    # LEGACY: token IDs passed positionally, with `prompts` and
    # `pooling_params` explicitly set to None.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy: single or batched token IDs passed positionally."""

nunjunj's avatar
nunjunj committed
946
947
948
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Either a single value or
                a sequence of values may be given; explicitly passed values
                are verified against the model config.
            prompt_token_ids: Legacy way of passing pre-tokenized prompts,
                optionally alongside text `prompts`. Prefer passing
                [TokensPrompt][vllm.inputs.TokensPrompt] objects via
                `prompts` instead.
            truncate_prompt_tokens: Optional number of tokens to truncate each
                prompt to. It is passed, together with the model's
                `max_model_len`, to `_validate_truncation_size`, which fills
                the tokenization kwargs used for the requests.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the pooling
                runner.

        Note:
            Passing token IDs through the `prompt_token_ids` parameter is
            considered legacy and may be deprecated in the future. You should
            instead pass them as part of the `prompts` parameter.
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                # Give a hint when the model could run as a pooling model but
                # was started with a different runner.
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call style: merge text prompts and token IDs.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
1042

1043
1044
1045
1046
1047
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute an embedding vector for every prompt.

        Prompts are batched automatically with the memory constraint taken
        into account; for the best performance, submit all of them in a
        single list.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional number of tokens to truncate each
                prompt to; forwarded to `encode`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model's task is not `embed`.
        """
        task = self.llm_engine.model_config.task
        if task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        truncate_prompt_tokens: Optional[int] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            truncate_prompt_tokens: Optional number of tokens to truncate each
                prompt to; forwarded to `encode` (same as in `embed`).
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Forwarded to `encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model's task is not `classify`.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        # The new keyword-only options default to None, which is exactly what
        # `encode` used before, so existing callers see no change.
        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1136
1137
1138
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their embeddings.

        Both sides are embedded in one `encode` call (`text_1` entries first,
        then `text_2` entries). When `text_1` yields exactly one embedding,
        that embedding is paired with every embedding from `text_2`.
        """
        num_queries = len(text_1)
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        query_embeds: list[PoolingRequestOutput] = embeddings[:num_queries]
        doc_embeds: list[PoolingRequestOutput] = embeddings[num_queries:]

        # 1 -> N scoring: reuse the single query embedding for each document.
        if len(query_embeds) == 1:
            query_embeds = query_embeds * len(doc_embeds)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=query_embeds,
                                    embed_2=doc_embeds)

        validated = self.engine_class.validate_outputs(scores,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score `<text_1, text_2>` pairs with a cross-encoder model.

        Each pair is tokenized jointly (`text=` / `text_pair=`) and submitted
        as one pooling request. Each pooling output is wrapped in a
        `ScoringRequestOutput`. When `text_1` has a single entry, it is
        paired with every entry of `text_2`.

        Args:
            tokenizer: Tokenizer used to encode each text pair.
            text_1: Query texts (one, or one per entry of `text_2`).
            text_2: Texts to pair with the queries.
            truncate_prompt_tokens: Optional number of tokens to truncate each
                pair to. It is passed, together with the model's
                `max_model_len`, to `_validate_truncation_size`, which fills
                the kwargs given to the tokenizer.
            use_tqdm: Progress-bar setting forwarded to request submission
                and to the engine run.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One `ScoringRequestOutput` per input pair.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The rejected condition is the tokenizer type, not the task; the
            # old message ("only enabled for `--task embed or score`") sent
            # users looking at the wrong setting.
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        # 1 -> N: broadcast the single query against every text in text_2.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []

        for q, t in input_pairs:
            # Tokenize the pair jointly; token_type_ids are carried over
            # when the tokenizer returns them.
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1221
1222
1223
1224
1225
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
1226
        *,
1227
        truncate_prompt_tokens: Optional[int] = None,
1228
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
1229
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
1230
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
1231
    ) -> list[ScoringRequestOutput]:
1232
        """Generate similarity scores for all pairs `<text,text_pair>`.
1233

1234
1235
1236
        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 - N` case the `text_1` sentence will be replicated `N`
        times to pair with the `text_2` sentences.
1237
        The input pairs are used to build a list of prompts for the
1238
1239
1240
1241
1242
1243
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
1244
                case it has to have the same length as the `text_2` list
1245
            text_2: The texts to pair with the query to form the input
1246
                to the LLM. See [PromptType][vllm.inputs.PromptType] for
1247
                more details about the format of each prompts.
1248
1249
1250
1251
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
1252
1253
1254
1255
1256
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
1257
            A list of `ScoringRequestOutput` objects containing the
1258
1259
            generated scores in the same order as the input prompts.
        """
1260
1261
1262
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]
1263

1264
1265
1266
            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
1267
                messages.append(
1268
1269
1270
1271
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")
1272
1273
1274

            raise ValueError(" ".join(messages))

1275
        if self.llm_engine.model_config.task not in ("embed", "score"):
1276
            raise ValueError(
1277
                "Score API is only enabled for `--task embed or --task score`")
1278
1279
1280
1281

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
1282
        tokenizer = self.get_tokenizer()
1283

1284
1285
1286
1287
        def ensure_str(prompt: SingletonPrompt):
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
1288
                                     "supported for scoring")
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            assert type(prompt) is str
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
1300
        input_text_1: list[str] = [ensure_str(t) for t in text_1]
1301
1302
1303
1304

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
1305
        input_text_2: list[str] = [ensure_str(t) for t in text_2]
1306

1307
        _validate_score_input_lens(input_text_1, input_text_2)
1308

1309
        if self.llm_engine.model_config.is_cross_encoder:
1310
1311
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
1312
1313
1314
1315
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
1316
1317
1318
1319
1320
1321
1322
1323
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1324

1325
1326
1327
1328
1329
1330
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1331
1332
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Passed through unchanged to the engine's
                ``reset_prefix_cache``; defaults to ``None``.

        Returns:
            The boolean reported by the engine (presumably whether the reset
            succeeded; confirm against the engine implementation).
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1333

1334
1335
1336
1337
1338
1339
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep; it must not process any requests while asleep.

        It is the caller's responsibility to make sure no requests are being
        processed for the whole sleep period, i.e. until `wake_up` is called.

        Args:
            level: How deeply to sleep.
                Level 1 offloads the model weights and discards the kv cache,
                whose content is forgotten. The weights are backed up in CPU
                memory, so make sure there is enough CPU memory to hold them.
                Use it when the same model will run again after waking up.
                Level 2 discards both the model weights and the kv cache; the
                content of both is forgotten. Use it when a different or
                updated model will run after waking up and the previous
                weights are not needed. It reduces CPU memory pressure.
        """
        # Sleeping forgets the kv cache at every level, so cached prefixes
        # would presumably point at discarded content; drop them first.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1356
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Bring the engine back from sleep mode. See the [sleep][] method
        for more details.

        Args:
            tags: Which engine memory allocations to restore. Allowed values
                are `("weights", "kv_cache")`; `None` restores everything.
                Before the engine is used again, `wake_up` must have been
                called covering every tag (or once with `None`).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1369

1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects capturing the current state of all
            aggregated metrics from Prometheus.

        Raises:
            ValueError: If this LLM is not backed by the V1 LLM engine.

        Note:
            This method is only available with the V1 LLM engine.
        """
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

        # An explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, and the call below would then likely fail with a
        # far less helpful error on a non-V1 engine.
        if not isinstance(self.llm_engine, V1LLMEngine):
            raise ValueError(
                "LLM.get_metrics() is only supported with the V1 LLM engine.")
        return self.llm_engine.get_metrics()

1384
1385
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert legacy ``prompts`` / ``prompt_token_ids`` into prompts.

        Both arguments are first normalized into batches with
        ``parse_and_batch_prompt``, so a single ``str`` or a single
        ``list[int]`` each count as exactly one prompt.

        Args:
            prompts: A single text prompt or a list of text prompts.
            prompt_token_ids: A single token-ID prompt or a list of them.

        Returns:
            One ``TextPrompt`` per text prompt when ``prompts`` is given
            (``prompt_token_ids`` is then only used for the batch-size
            check), otherwise one ``TokensPrompt`` per token-ID prompt.

        Raises:
            ValueError: If neither argument is given, or if both are given
                with different batch sizes.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Batch before comparing sizes: comparing the raw arguments would pit
        # the character count of a single ``str`` against the token count of
        # a single ``list[int]`` and reject valid one-prompt inputs.
        text_batch = None
        if prompts is not None:
            text_batch = [p["content"] for p in parse_and_batch_prompt(prompts)]

        token_batch = None
        if prompt_token_ids is not None:
            token_batch = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if (text_batch is not None and token_batch is not None
                and len(text_batch) != len(token_batch)):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        parsed_prompts: list[PromptType]
        if text_batch is not None:
            # Text prompts take precedence when both inputs are provided.
            parsed_prompts = [TextPrompt(prompt=text) for text in text_batch]
        else:
            assert token_batch is not None  # guaranteed by the check above
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=ids) for ids in token_batch
            ]

        return parsed_prompts
1425
1426
1427

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments, then submit every prompt.

        All validation happens before the first request is added, so a bad
        argument cannot leave the engine holding a partially added batch.

        Args:
            prompts: A single prompt (``str`` or prompt ``dict``) or a
                sequence of prompts.
            params: One params object shared by all prompts, or a sequence
                with exactly one entry per prompt.
            use_tqdm: If truthy, show a progress bar while adding requests;
                a callable is used as the tqdm factory.
            lora_request: One LoRA request shared by all prompts, or a
                sequence with exactly one entry per prompt.
            prompt_adapter_request: Prompt adapter request shared by all
                prompts.
            tokenization_kwargs: Forwarded to every ``_add_request`` call.
            guided_options: Deprecated. Applied to every ``SamplingParams``
                via ``_add_guided_params``.
            priority: Optional per-prompt priorities; when non-empty it must
                have exactly one entry per prompt. Requests default to 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or ``priority`` does
                not have one entry per prompt, or if ``guided_options`` is
                combined with ``params.guided_decoding``.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use ``Sequence`` throughout, matching the indexing below; checking
        # only ``list`` let tuples bypass validation.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # An empty priority list keeps meaning "default priority for all".
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        all_params = params if isinstance(params, Sequence) else (params, )
        # The same params object may appear several times (e.g. [sp] * n);
        # configure each object once, or the second visit would see its own
        # guided_decoding and raise "Cannot set both ...".
        configured: set[int] = set()
        for sp in all_params:
            if isinstance(sp, SamplingParams) and id(sp) not in configured:
                configured.add(id(sp))
                self._add_guided_params(sp, guided_options)

                # `_run_engine` only collects finished outputs, so we only
                # care about the final output.
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1483

1484
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the engine under a fresh request ID.

        Args:
            prompt: The prompt to process.
            params: Sampling or pooling parameters for this request.
            tokenization_kwargs: Forwarded to the engine unchanged.
            lora_request: Optional LoRA request for this prompt.
            prompt_adapter_request: Optional prompt adapter request.
            priority: Scheduling priority forwarded to the engine.
        """
        # `_run_engine` sorts finished outputs by `int(request_id)`, so the ID
        # must be the string form of an integer from the request counter.
        new_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            new_id,
            prompt,
            params,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1503

1504
    def _add_guided_params(
        self,
        params: SamplingParams,
        guided_options: Optional[GuidedDecodingRequest] = None,
    ):
        """Copy deprecated guided-decoding options into ``params`` in place.

        Args:
            params: Sampling parameters to update; mutated in place.
            guided_options: Legacy guided-decoding request. When ``None``,
                ``params`` is left untouched.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )
        return params

1527
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until every queued request has finished.

        Args:
            use_tqdm: If truthy, show a progress bar over the queued requests
                with estimated input/output token throughput; a callable is
                used as the tqdm factory.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if not output.finished:
                    continue
                outputs.append(output)
                if not use_tqdm:
                    continue
                if isinstance(output, RequestOutput):
                    # Calculate tokens only for RequestOutput
                    n = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids) * n
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs)
                    # tqdm can report an elapsed time of 0 (e.g. a bar created
                    # with disable=True via a custom factory); only refresh
                    # the speed estimate when the division is defined.
                    elapsed = pbar.format_dict["elapsed"]
                    if elapsed:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                    pbar.update(n)
                else:
                    pbar.update(1)

        if use_tqdm:
            pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))