llm.py 70.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Sequence
7
from contextlib import contextmanager
8
9
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
10

11
import cloudpickle
12
import torch.nn as nn
13
from pydantic import ValidationError
14
from tqdm.auto import tqdm
15
from typing_extensions import TypeVar, deprecated
16

17
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
18
19
                              BeamSearchSequence,
                              create_sort_beams_key_function)
20
21
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
22
23
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
24
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
25
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
26
                                         ChatTemplateContentFormatOption,
27
28
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
29
30
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
31
32
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
33
from vllm.entrypoints.utils import _validate_truncation_size
34
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
35
from vllm.inputs.parse import parse_and_batch_prompt
36
from vllm.logger import init_logger
37
from vllm.lora.request import LoRARequest
38
39
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
40
from vllm.model_executor.layers.quantization import QuantizationMethods
41
42
43
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
44
from vllm.pooling_params import PoolingParams
45
from vllm.prompt_adapter.request import PromptAdapterRequest
46
47
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
48
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
49
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
50
from vllm.usage.usage_lib import UsageContext
51
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
52

53
54
55
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

56
57
logger = init_logger(__name__)

# Generic result type of callables fanned out to the workers by
# `LLM.collective_rpc` / `LLM.apply_model`; falls back to `Any`.
_R = TypeVar("_R", default=Any)


class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
62
63
64
65
66
67
68
69
70
71
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        task: The task to use the model for (e.g. `"generate"`). Forwarded
            to `EngineArgs`; defaults to `"auto"`.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allows API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights, e.g.
            "awq", "gptq", or "fp8" (see `QuantizationMethods` for the
            accepted values). If None, we first check the
            `quantization_config` attribute in the model config file. If that
            is None, we assume the model weights are not quantized and use
            `dtype` to determine the data type of the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        override_pooler_config: Initialize non-default pooling config or
            override default pooling config for the pooling model.
            e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
        compilation_config: An integer, a dictionary, or a
            `CompilationConfig`. If it is an integer, it is used as the level
            of compilation optimization. If it is a dictionary, it can specify
            the full compilation configuration; keys that are not
            `CompilationConfig` init fields are ignored.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = True
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

165
166
167
    def __init__(
        self,
        model: str,
        *,
        task: TaskOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        The constructor first normalizes a few loosely-typed arguments:
        - a `worker_cls` class is pickled;
        - a `kv_transfer_config` dict becomes a `KVTransferConfig`;
        - `hf_overrides=None` becomes `{}`;
        - `compilation_config` becomes a `CompilationConfig`.

        It then builds an `EngineArgs` from them plus any extra `**kwargs`,
        and creates the engine with `LLMEngine.from_engine_args`. See the
        class docstring for the meaning of each argument.

        Raises:
            ValueError: If `kv_transfer_config` is a dict that fails
                validation as a `KVTransferConfig`.
        """
        # Stats logging is off by default unless the caller opts in.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # If worker_cls is a class rather than a qualified-name string,
            # serialize it with cloudpickle to avoid pickling issues.
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if isinstance(kwargs.get("kv_transfer_config"), dict):
            from vllm.config import KVTransferConfig
            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict, e)
                # Surface a ValueError (with the pydantic error chained) so
                # the user gets context about the bad config.
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Keep only keys for which is_init_field() holds; any other keys
            # are dropped without error.
            compilation_config_instance = CompilationConfig(
                **{
                    key: value
                    for key, value in compilation_config.items()
                    if is_init_field(CompilationConfig, key)
                })
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for `lora_request` (which may be `None`).

        The lookup is delegated to the engine's tokenizer group via
        `get_lora_tokenizer`.
        """
        group = self.llm_engine.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the engine's tokenizer group.

        A tokenizer whose class name does not start with "Cached" is first
        wrapped with `get_cached_tokenizer`; otherwise it is stored as-is.
        """
        group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are generated dynamically, so the class
        # name prefix is the only available check. A user-defined tokenizer
        # whose class name happens to start with "Cached" is misjudged.
        is_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer
                           if is_cached else get_cached_tokenizer(tokenizer))

    def get_default_sampling_params(self) -> SamplingParams:
        """Build the default `SamplingParams` for this model.

        The overrides from `model_config.get_diff_sampling_param()` are
        fetched on first use and memoized in `self.default_sampling_params`.
        An empty result yields a plain `SamplingParams()`.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides

        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

    # Typing overloads for `generate`. Every overload declares `priority`
    # because the implementation accepts it in all call forms.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Deprecated legacy input. Token IDs of the
                prompt(s), converted together with `prompts` by
                `_convert_v1_inputs`. Pass token IDs inside `prompts` instead.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                may specify at most one guided decoding option and is
                converted into a `GuidedDecodingRequest`.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the "generate" or
                "transcription" runner, or if `guided_options_request` is a
                dict with more than one entry.

        Note:
            Passing token IDs through the `prompt_token_ids` keyword is
            legacy and deprecated. Pass them via the `prompts` parameter
            instead, e.g. as [TokensPrompt][vllm.inputs.TokensPrompt] objects.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = (
                self.llm_engine.model_config.supported_runner_types)
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call form: convert `prompts` + `prompt_token_ids` with
            # the v1-input adapter.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Truncation is only read from a single SamplingParams; for a
        # per-prompt list, truncate_prompt_tokens stays None here.
        tokenization_kwargs: dict[str, Any] = {}
        truncate_prompt_tokens = None
        if isinstance(sampling_params, SamplingParams):
            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Run an RPC on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to every worker for execution.

                A callable must accept an additional `self` argument (the
                worker object) besides whatever is supplied through `args`
                and `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait. A
                [`TimeoutError`][] is raised when it is exceeded; `None`
                waits indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, collected into a list.

        Note:
            Prefer this API for control messages only; move bulk data over
            a dedicated data-plane communication channel.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Invoke `func` on the model held by each worker.

        Delegates to the engine's `model_executor.apply_model` and returns
        the per-worker results it produces.
        """
        return self.llm_engine.model_executor.apply_model(func)

    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: `None`, a single `LoRARequest` shared by every
                prompt, or a sequence holding one entry per prompt.
            prompts: The prompts the requests are paired with.

        Returns:
            A list with one (possibly `None`) LoRA request per prompt.

        Raises:
            ValueError: If a sequence of requests does not have the same
                length as `prompts`.
            TypeError: If `lora_request` is of any other type.
        """
        # A per-prompt sequence was previously length-checked and then
        # rejected with TypeError even when valid; return it instead.
        # Strings are excluded so they still fall through to TypeError.
        if isinstance(lora_request, Sequence) and not isinstance(
                lora_request, str):
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the "
                    "prompts")
            # Copy so the caller's sequence is not aliased.
            return list(lora_request)

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

    def beam_search(
        self,
569
        prompts: list[Union[TokensPrompt, TextPrompt]],
570
        params: BeamSearchParams,
571
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
572
        use_tqdm: bool = False,
573
    ) -> list[BeamSearchOutput]:
574
575
576
577
578
579
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
580
            params: The beam search parameters.
581
            lora_request: LoRA request to use for generation, if any.
582
            use_tqdm: Whether to use tqdm to display the progress bar.
583
        """
584
585
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
586
587
588
589
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
590
591
        length_penalty = params.length_penalty

592
593
594
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

595
596
597
598
599
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
600

601
602
603
604
605
606
607
608
609
610
611
612
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
613

614
615
616
617
618
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
619
                                            temperature=temperature)
620
        instances: list[BeamSearchInstance] = []
621

622
        for lora_req, prompt in zip(lora_requests, prompts):
623
624
625
626
627
628
629
630
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

631
632
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
633
634
635
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
636

637
            instances.append(
638
639
640
641
642
643
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
644

645
646
647
648
649
650
651
652
653
654
655
656
        token_iter = range(max_tokens)
        if use_tqdm:
            token_iter = tqdm(token_iter,
                              desc="Beam search",
                              unit="token",
                              unit_scale=False)
            logger.warning(
                "The progress bar shows the upper bound on token steps and "
                "may finish early due to stopping conditions. It does not "
                "reflect instance-level progress.")

        for _ in token_iter:
657
            all_beams: list[BeamSearchSequence] = list(
658
659
660
661
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
662
            instance_start_and_end: list[tuple[int, int]] = list(
663
664
665
666
667
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

668
669
670
671
            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])
672
673
674
675
676

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
677
678
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
695
                                logprobs=current_beam.logprobs + [logprobs],
696
                                lora_request=current_beam.lora_request,
697
                                cum_logprob=current_beam.cum_logprob +
698
699
700
701
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)
702
703
704
705
706
707
708

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
709
                                      key=sort_beams_key,
710
711
712
713
714
715
716
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
717
                                      key=sort_beams_key,
718
719
720
721
722
723
724
725
726
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered with the chat template, tokenized
        into a `TokensPrompt`, and passed to the [generate][] method to
        generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. It also
                selects the tokenizer used to render the conversations.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "auto" (default) lets `resolve_chat_template_content_format`
                  pick the format based on the template, tools and tokenizer.
                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, asks the chat template to append
                a generation prompt to each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are passed to the chat
                template and taken into account when resolving the content
                format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. They override the keyword arguments above that
                share the same name.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # User-supplied chat_template_kwargs take precedence over the
        # explicit arguments because they are applied last.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path works on the raw messages and returns
                # token IDs directly.
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # Preferred form: prompts (text or token IDs) are passed positionally.
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Explicitly passed
                parameters (single or per-prompt) are verified against the
                model config.
            prompt_token_ids: LEGACY. Token IDs of the prompt(s); combined
                with `prompts` via `_convert_v1_inputs`.
            truncate_prompt_tokens: Optional prompt truncation length (in
                tokens). It is validated against the model's `max_model_len`
                and converted into the tokenization options for the requests.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the 'pooling'
                runner.

        Note:
            Passing `prompt_token_ids` as a separate parameter is considered
            legacy and is deprecated when `LLM.DEPRECATE_LEGACY` is set.
            You should instead pass token IDs via the `prompts` parameter,
            e.g. as `TokensPrompt(prompt_token_ids=...)`.
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call style: merge separate text/token-id inputs.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)

    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional prompt truncation length (in
                tokens), forwarded to `encode`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with `--task embed`.
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        truncate_prompt_tokens: Optional[int] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            truncate_prompt_tokens: Optional prompt truncation length (in
                tokens), forwarded to `encode`. Defaults to None, which is
                what `encode` uses when the argument is omitted.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with
                `--task classify`.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by cosine similarity of their pooled outputs.

        All inputs from `text_1` and `text_2` are encoded in a single
        `encode` call and then split back apart. If `text_1` holds a single
        item, its output is reused for every item in `text_2`.

        Returns:
            One `ScoringRequestOutput` per pair, in input order.
        """
        # One batched encode for both sides; split by position afterwards.
        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        split_at = len(text_1)
        encoded_output_1: list[PoolingRequestOutput] = encoded_output[
            :split_at]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[
            split_at:]

        # 1 -> N case: pair the single left-hand output with every right one.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by running each (text_1, text_2) pair jointly.

        Each pair is tokenized together with `tokenizer(text=..., text_pair=...)`
        and submitted as a single `TokensPrompt` (including `token_type_ids`
        when the tokenizer returns them) with default `PoolingParams`. If
        `text_1` holds a single item, it is paired with every item of
        `text_2`.

        Returns:
            One `ScoringRequestOutput` per pair, in input order.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The (text, text_pair) tokenization below is not available for
            # Mistral tokenizers, so reject them up front.
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        # 1 -> N case: replicate the single query for every candidate.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `text_1` sentence will be replicated `N`
        times to pair with the `text_2` sentences.
        The input pairs are used to build a list of prompts for the
        cross encoder model (for models that are not cross encoders,
        `text_1` and `text_2` are embedded separately and scored from their
        embeddings). This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the `text_2` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See [PromptType][vllm.inputs.PromptType] for
                more details about the format of each prompt.
            truncate_prompt_tokens: Optional token limit forwarded unchanged
                to the scoring helper, which presumably truncates each
                tokenized input to this many tokens.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running as a pooling model, its
                task is neither `embed` nor `classify`, a `classify` model
                does not have exactly one label, a prompt carries
                multi-modal data, or a prompt cannot be converted to text.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if model_config.task not in ("embed", "classify"):
            raise ValueError("Score API is only enabled for "
                             "`--task embed or --task classify`.")

        if (model_config.task == "classify"
                and model_config.hf_config.num_labels != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs, so every
        # prompt is normalized to plain text first.
        tokenizer = self.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            """Convert a single prompt (str, TextPrompt or TokensPrompt)
            into plain text."""
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check rather than `assert`: asserts are stripped under
            # `python -O`, which would let a non-str prompt slip through.
            if not isinstance(prompt, str):
                raise ValueError(
                    "Scoring prompts must be strings, TextPrompt or "
                    f"TokensPrompt; got {type(prompt).__name__}")
            return prompt

        def to_str_list(texts) -> list[str]:
            """Wrap a single prompt in a list and convert each to text."""
            if isinstance(texts, (str, dict)):
                # Convert a single prompt to a list.
                texts = [texts]
            return [ensure_str(t) for t in texts]

        input_text_1: list[str] = to_str_list(text_1)
        input_text_2: list[str] = to_str_list(text_2)

        _validate_score_input_lens(input_text_1, input_text_2)

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        return self._embedding_score(
            tokenizer,
            input_text_1,  # type: ignore[arg-type]
            input_text_2,  # type: ignore[arg-type]
            truncate_prompt_tokens,
            use_tqdm,
            lora_request,
            prompt_adapter_request)
1345

1346
1347
1348
1349
1350
1351
    def start_profile(self) -> None:
        """Forward a start-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Forward a stop-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

1352
1353
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, passed through unchanged to
                `llm_engine.reset_prefix_cache`.

        Returns:
            The boolean reported by the engine.
        """
        engine = self.llm_engine
        result = engine.reset_prefix_cache(device)
        return result
1354

1355
1356
1357
1358
1359
1360
    def sleep(self, level: int = 1):
        """Put the engine to sleep so that it stops processing requests.

        The caller must make sure no requests are in flight from the moment
        this is called until `wake_up` is invoked.

        Args:
            level: How deeply to sleep.

                * Level 1 offloads the model weights and discards the kv
                  cache, whose content is forgotten. The weights are backed
                  up in CPU memory, so make sure there is enough CPU memory
                  to hold them. Use this level when the engine will wake up
                  to run the same model again.
                * Level 2 discards both the model weights and the kv cache;
                  the content of both is forgotten. Use this level when the
                  engine will wake up to run a different or updated model
                  and the previous weights are not needed. It reduces CPU
                  memory pressure.
        """
        # Clear the prefix cache before sleeping; presumably because every
        # sleep level discards the kv cache that its entries point into.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1377
    def wake_up(self, tags: Optional[list[str]] = None):
        """Bring the engine back from sleep mode.

        See the [sleep][] method for what each sleep level releases.

        Args:
            tags: Optional subset of memory allocations to restore, each
                taken from `("weights", "kv_cache")`. With None, all memory
                is reallocated. Before the engine is used again, wake_up must
                have been called with every tag (or once with None).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1390

1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects capturing the current state
            of all aggregated metrics from Prometheus.

        Raises:
            ValueError: If this LLM is not backed by the V1 LLM engine.

        Note:
            This method is only available with the V1 LLM engine.
        """
        # Function-scope import, presumably so the V1 engine module is only
        # loaded when metrics are actually requested.
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

        # Explicit check instead of `assert`: asserts are stripped under
        # `python -O`, which would turn this into an obscure AttributeError.
        if not isinstance(self.llm_engine, V1LLMEngine):
            raise ValueError(
                "LLM.get_metrics() is only supported with the V1 LLM engine.")
        return self.llm_engine.get_metrics()

1405
1406
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert legacy ``prompts`` / ``prompt_token_ids`` arguments into a
        list of prompt objects, one per request.

        Args:
            prompts: A single text prompt or a list of text prompts.
            prompt_token_ids: A single list of token ids or a list of them.

        Returns:
            A list with one ``TextPrompt`` per request when ``prompts`` is
            given (text takes precedence), otherwise one ``TokensPrompt`` per
            request.

        Raises:
            ValueError: If neither argument is given, or if both are given
                but describe a different number of requests.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Normalize both inputs into batches BEFORE comparing their lengths:
        # a single str and a single list[int] each describe one request, but
        # their raw len() values count characters and tokens respectively.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        parsed_prompts: list[PromptType]
        if prompts is not None:
            parsed_prompts = [TextPrompt(prompt=text) for text in prompts]
        elif prompt_token_ids is not None:
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=ids) for ids in prompt_token_ids
            ]
        else:  # unreachable: rejected by the first check above
            raise AssertionError

        return parsed_prompts
1446
1447
1448

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments, then add every prompt to the
        engine.

        Args:
            prompts: One prompt or a sequence of prompts.
            params: One params object shared by all prompts, or a sequence
                with one entry per prompt.
            use_tqdm: Truthy to show a progress bar while adding requests; a
                callable is used as the tqdm factory.
            lora_request: One LoRA request for all prompts, a sequence with
                one entry per prompt, or None.
            prompt_adapter_request: Prompt adapter request applied to every
                prompt, if any.
            tokenization_kwargs: Extra kwargs forwarded with each request.
            guided_options: Deprecated legacy guided-decoding options, merged
                into each ``SamplingParams``.
            priority: Optional per-prompt priorities; None or an empty list
                means priority 0 for every request.

        Raises:
            ValueError: If ``params``, ``lora_request`` or a non-empty
                ``priority`` does not have one entry per prompt, or (from
                ``_add_guided_params``) if ``guided_options`` is combined with
                ``SamplingParams.guided_decoding``.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Validate up front, like params/lora_request above: otherwise a short
        # list fails with IndexError only after some requests were already
        # handed to the engine. An empty list keeps meaning "default (0)".
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1504

1505
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine under a fresh request ID.

        The ID is the next value of ``self.request_counter`` rendered as a
        string; ``_run_engine`` later sorts finished outputs by
        ``int(request_id)``, which relies on this numeric form.
        """
        next_id = next(self.request_counter)
        engine = self.llm_engine
        engine.add_request(
            str(next_id),
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1524

1525
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Fold legacy guided-decoding options into ``params`` in place.

        Args:
            params: Sampling params to update.
            guided_options: Legacy options; when None, ``params`` is left
                untouched.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")
            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )
        return params

1548
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Truthy to show a progress bar with estimated token
                throughput; a callable is used as the tqdm factory.

        Returns:
            Every finished output, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if not output.finished:
                    continue
                outputs.append(output)
                if not use_tqdm:
                    continue
                if isinstance(output, RequestOutput):
                    # Calculate tokens only for RequestOutput
                    n = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids) * n
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs)
                    # `elapsed` may be 0 (e.g. a disabled tqdm bar, or a
                    # coarse clock right after creation); dividing by it
                    # would raise ZeroDivisionError.
                    elapsed = pbar.format_dict["elapsed"]
                    if elapsed > 0:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                    else:
                        in_spd = out_spd = 0.0
                    pbar.postfix = (
                        f"est. speed input: {in_spd:.2f} toks/s, "
                        f"output: {out_spd:.2f} toks/s")
                    pbar.update(n)
                else:
                    pbar.update(1)
                if pbar.n == num_requests:
                    pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))