llm.py 69 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Sequence
7
from contextlib import contextmanager
8
9
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
10

11
import cloudpickle
12
import torch.nn as nn
13
from pydantic import ValidationError
14
from tqdm.auto import tqdm
15
from typing_extensions import TypeVar, deprecated
16

17
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
18
19
                              BeamSearchSequence,
                              create_sort_beams_key_function)
20
21
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
22
23
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
24
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
25
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
26
                                         ChatTemplateContentFormatOption,
27
28
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
29
30
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
31
32
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
33
from vllm.entrypoints.utils import _validate_truncation_size
34
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
35
from vllm.inputs.parse import parse_and_batch_prompt
36
from vllm.logger import init_logger
37
from vllm.lora.request import LoRARequest
38
39
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
40
from vllm.model_executor.layers.quantization import QuantizationMethods
41
42
43
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
44
from vllm.pooling_params import PoolingParams
45
from vllm.prompt_adapter.request import PromptAdapterRequest
46
47
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
48
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
49
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
50
from vllm.usage.usage_lib import UsageContext
51
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
52

53
54
55
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

56
57
logger = init_logger(__name__)

58
59
_R = TypeVar("_R", default=Any)

60
61

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
62
63
64
65
66
67
68
69
70
71
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
72
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
73
74
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
75
76
77
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
78
79
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
80
81
82
83
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
84
85
86
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
87
88
89
90
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
91
        quantization: The method used to quantize the model weights. Currently,
92
            we support "awq", "gptq", and "fp8" (experimental).
93
94
95
96
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
97
98
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
99
100
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
101
102
103
104
105
106
107
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
108
109
110
111
112
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Noting that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
113
114
115
116
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
117
118
119
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
120
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
121
            When a sequence has context length larger than this, we fall back
122
123
124
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
125
126
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
127
128
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
129
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
132
133
134
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
135
136
137
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
138
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
139

140
141
    Note:
        This class is intended to be used for offline inference. For online
142
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
143
    """
144

145
    DEPRECATE_LEGACY: ClassVar[bool] = True
146
147
148
149
150
151
152
153
154
155
156
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

157
158
159
    def __init__(
        self,
        model: str,
        *,
        task: TaskOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """Normalize the arguments, build `EngineArgs`, and create the engine.

        The explicit parameters are forwarded unchanged to `EngineArgs`,
        together with `**kwargs`. Before that, a few values are normalized:

        - `disable_log_stats` defaults to `True` when not given in `kwargs`.
        - A `worker_cls` passed as a class object is serialized with
          cloudpickle.
        - A `kv_transfer_config` passed as a dict is converted to a
          `KVTransferConfig`.
        - `hf_overrides=None` becomes an empty dict.
        - `compilation_config` may be `None` (default `CompilationConfig`),
          an int (used as `level`), a dict (only keys that are init fields
          of `CompilationConfig` are kept), or any other object, which is
          passed through as-is.

        Raises:
            ValueError: If a dict `kv_transfer_config` fails validation.
        """
        # Stats logging stays disabled unless the caller set it explicitly.
        kwargs.setdefault("disable_log_stats", True)

        # If the worker_cls is not a qualified string name, serialize it
        # using cloudpickle to avoid pickling issues.
        custom_worker = kwargs.get("worker_cls")
        if isinstance(custom_worker, type):
            kwargs["worker_cls"] = cloudpickle.dumps(custom_worker)

        kv_transfer_raw = kwargs.get("kv_transfer_config")
        if isinstance(kv_transfer_raw, dict):
            from vllm.config import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **kv_transfer_raw)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    kv_transfer_raw, e)
                # Surface the failure as a ValueError so the user sees which
                # argument was rejected.
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        hf_overrides = {} if hf_overrides is None else hf_overrides

        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Keep only the entries that CompilationConfig accepts as init
            # fields; every other key is dropped.
            init_fields = {
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            }
            compilation_config_instance = CompilationConfig(**init_fields)
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None
269

270
271
272
273
274
275
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
            lora_request)
276
277

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
278
        tokenizer_group = self.llm_engine.get_tokenizer_group()
279

280
281
282
283
        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        if tokenizer.__class__.__name__.startswith("Cached"):
284
            tokenizer_group.tokenizer = tokenizer
285
        else:
286
            tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
287

288
    def get_default_sampling_params(self) -> SamplingParams:
289
290
291
292
293
        if self.default_sampling_params is None:
            self.default_sampling_params = (
                self.llm_engine.model_config.get_diff_sampling_param())
        if self.default_sampling_params:
            return SamplingParams.from_optional(**self.default_sampling_params)
294
295
        return SamplingParams()

296
297
298
299
300
301
302
    # Preferred (non-legacy) call form: prompts are passed positionally and
    # everything after `sampling_params` is keyword-only.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        # The implementation of `generate` accepts `priority`; it is declared
        # here so that type checkers allow it on the non-legacy call form.
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

312
    # The overloads below describe the legacy call forms that still accept a
    # separate `prompt_token_ids` argument. All of them are marked deprecated.

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
393
394
395
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
396
        additional_message="Please use the 'prompts' parameter instead.",
nunjunj's avatar
nunjunj committed
397
    )
398
399
    def generate(
        self,
400
        prompts: Union[Union[PromptType, Sequence[PromptType]],
401
                       Optional[Union[str, list[str]]]] = None,
402
403
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
404
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
405
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
406
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
407
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
408
        guided_options_request: Optional[Union[LLMGuidedOptions,
409
                                               GuidedDecodingRequest]] = None,
410
411
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
Woosuk Kwon's avatar
Woosuk Kwon committed
412
413
        """Generates the completions for the input prompts.

414
        This class automatically batches the given prompts, considering
Woosuk Kwon's avatar
Woosuk Kwon committed
415
416
417
418
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
419
            prompts: The prompts to the LLM. You may pass a sequence of prompts
420
                for batch inference. See [PromptType][vllm.inputs.PromptType]
421
                for more details about the format of each prompts.
Woosuk Kwon's avatar
Woosuk Kwon committed
422
            sampling_params: The sampling parameters for text generation. If
nunjunj's avatar
nunjunj committed
423
424
425
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
426
                prompts and it is paired one by one with the prompt.
427
428
429
430
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
431
            lora_request: LoRA request to use for generation, if any.
nunjunj's avatar
nunjunj committed
432
            prompt_adapter_request: Prompt Adapter request to use for
433
                generation, if any.
434
435
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
Woosuk Kwon's avatar
Woosuk Kwon committed
436
437

        Returns:
438
            A list of `RequestOutput` objects containing the
439
            generated completions in the same order as the input prompts.
440

441
442
443
444
        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `inputs` parameter.
445
        """
446
        runner_type = self.llm_engine.model_config.runner_type
447
        if runner_type not in ["generate", "transcription"]:
448
            messages = [
449
                "LLM.generate() is only supported for (conditional) generation "
450
451
452
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

453
454
455
            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
456
                messages.append(
457
458
459
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")
460
461

            raise ValueError(" ".join(messages))
462

463
        if prompt_token_ids is not None:
464
            parsed_prompts = self._convert_v1_inputs(
465
                prompts=cast(Optional[Union[str, list[str]]], prompts),
466
467
468
                prompt_token_ids=prompt_token_ids,
            )
        else:
469
470
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)
471

472
473
474
475
476
477
478
479
        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

480
481
        if sampling_params is None:
            # Use default sampling params.
482
            sampling_params = self.get_default_sampling_params()
483

484
        self._validate_and_add_requests(
485
            prompts=parsed_prompts,
486
            params=sampling_params,
487
            use_tqdm=use_tqdm,
488
            lora_request=lora_request,
489
            prompt_adapter_request=prompt_adapter_request,
490
            guided_options=guided_options_request,
491
492
            priority=priority,
        )
493

494
        outputs = self._run_engine(use_tqdm=use_tqdm)
Joe Runde's avatar
Joe Runde committed
495
        return self.engine_class.validate_outputs(outputs, RequestOutput)
496

497
    def collective_rpc(self,
498
                       method: Union[str, Callable[..., _R]],
499
                       timeout: Optional[float] = None,
500
501
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
502
503
504
505
506
507
508
509
510
511
512
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
513
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
514
515
516
517
518
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.
519

520
521
522
        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
523
        """
524
525

        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
526
527

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
528
        """
529
530
        Run a function directly on the model inside each worker,
        returning the result for each of them.
531
        """
532
533
        executor = self.llm_engine.model_executor
        return executor.apply_model(func)
534

535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt."""
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

551
552
    def beam_search(
        self,
553
        prompts: list[Union[TokensPrompt, TextPrompt]],
554
        params: BeamSearchParams,
555
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
556
        use_tqdm: bool = False,
557
    ) -> list[BeamSearchOutput]:
558
559
560
561
562
563
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
564
            params: The beam search parameters.
565
            lora_request: LoRA request to use for generation, if any.
566
            use_tqdm: Whether to use tqdm to display the progress bar.
567
        """
568
569
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
570
571
572
573
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
574
575
        length_penalty = params.length_penalty

576
577
578
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

579
580
581
582
583
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
584

585
586
587
588
589
590
591
592
593
594
595
596
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
597

598
599
600
601
602
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
603
                                            temperature=temperature)
604
        instances: list[BeamSearchInstance] = []
605

606
        for lora_req, prompt in zip(lora_requests, prompts):
607
608
609
610
611
612
613
614
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

615
616
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
617
618
619
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
620

621
            instances.append(
622
623
624
625
626
627
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
628

629
630
631
632
633
634
635
636
637
638
639
640
        token_iter = range(max_tokens)
        if use_tqdm:
            token_iter = tqdm(token_iter,
                              desc="Beam search",
                              unit="token",
                              unit_scale=False)
            logger.warning(
                "The progress bar shows the upper bound on token steps and "
                "may finish early due to stopping conditions. It does not "
                "reflect instance-level progress.")

        for _ in token_iter:
641
            all_beams: list[BeamSearchSequence] = list(
642
643
644
645
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
646
            instance_start_and_end: list[tuple[int, int]] = list(
647
648
649
650
651
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

652
653
654
655
            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])
656
657
658
659
660

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
661
662
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
679
                                logprobs=current_beam.logprobs + [logprobs],
680
                                lora_request=current_beam.lora_request,
681
                                cum_logprob=current_beam.cum_logprob +
682
683
684
685
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)
686
687
688
689
690
691
692

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
693
                                      key=sort_beams_key,
694
695
696
697
698
699
700
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
701
                                      key=sort_beams_key,
702
703
704
705
706
707
708
709
710
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
711
712
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered with the chat template, tokenized,
        and passed to the [generate][] method to generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
                It is also used to select the tokenizer.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional list of tool definitions. It is forwarded both to
                the content-format resolution and to the chat template.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. Entries here override `chat_template`,
                `add_generation_prompt`, `continue_final_message` and `tools`
                when the same key is given.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Explicit arguments form the base; user-supplied kwargs are applied
        # last so they take precedence on key collisions.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path consumes the raw messages and yields
                # token ids directly.
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

856
857
858
859
860
861
862
    # Preferred form: `prompts` is positional-only and every option after
    # `pooling_params` is keyword-only.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

871
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...
885

886
    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
947
948
949
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Explicitly supplied
                parameters are verified against the model config.
            prompt_token_ids: LEGACY. Token ids for the prompt(s). When given,
                `prompts` and `prompt_token_ids` are converted together via
                `_convert_v1_inputs`. Prefer passing token ids as part of
                `prompts`.
            truncate_prompt_tokens: Optional truncation size. It is checked
                by `_validate_truncation_size` together with the model's
                `max_model_len`, and the resulting tokenization kwargs are
                forwarded with the requests.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the engine's runner type is not "pooling".

        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `prompts` parameter.
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            # Give an actionable hint when the model could have been started
            # with the pooling runner but was not.
            if "pooling" in model_config.supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy calling convention: merge text and token ids.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
1043

1044
1045
1046
1047
1048
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Optional truncation size, forwarded
                unchanged to `encode`.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model config's task is not "embed".
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        # Re-wrap each generic pooling output as an embedding output.
        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model config's task is not "classify".
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        # Re-wrap each generic pooling output as a classification output.
        return [ClassificationRequestOutput.from_base(item) for item in items]

1137
1138
1139
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score prompt pairs by comparing their pooled embeddings.

        Both prompt lists are encoded in a single `encode` call, the result
        is split back into the two sides, and `_cosine_similarity` is applied
        to the paired outputs.

        Args:
            tokenizer: Tokenizer forwarded to `_cosine_similarity`.
            text_1: First side of the pairs. A single-element list is
                replicated to match the length of `text_2`.
            text_2: Second side of the pairs.
            truncate_prompt_tokens: Optional truncation size, forwarded to
                `encode`.
            use_tqdm: Progress-bar setting, forwarded to `encode`.
            lora_request: LoRA request, forwarded to `encode`.
            prompt_adapter_request: Prompt Adapter request, forwarded to
                `encode`.

        Returns:
            A list of `ScoringRequestOutput` objects, one per pair.
        """
        # Encode both sides in one batch so the engine runs only once.
        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        # Split the batched result back into the two sides.
        num_first = len(text_1)
        encoded_output_1: list[PoolingRequestOutput] = encoded_output[
            0:num_first]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[
            num_first:]

        # 1 -> N case: reuse the single first-side output for every pair.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by feeding each pair to the model as one prompt.

        Each `(text_1[i], text_2[i])` pair is tokenized together via the
        tokenizer's `text` / `text_pair` arguments and submitted as a single
        `TokensPrompt` with default `PoolingParams`.

        Args:
            tokenizer: Tokenizer used to encode the pairs. A
                `MistralTokenizer` is rejected.
            text_1: First side of the pairs. A single-element list is
                replicated to match the length of `text_2`.
            text_2: Second side of the pairs.
            truncate_prompt_tokens: Optional truncation size, checked by
                `_validate_truncation_size`; the resulting kwargs are passed
                to the tokenizer call.
            use_tqdm: Progress-bar setting for request submission and the
                engine run.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            A list of `ScoringRequestOutput` objects, one per pair.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The previous message blamed the `--task` setting, which is
            # unrelated to this check.
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # 1 -> N case: pair the single first text with every second text.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []

        for q, t in zip(text_1, text_2):
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1222
1223
1224
1225
1226
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
1227
        *,
1228
        truncate_prompt_tokens: Optional[int] = None,
1229
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
1230
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
1231
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
1232
    ) -> list[ScoringRequestOutput]:
1233
        """Generate similarity scores for all pairs `<text,text_pair>`.
1234

1235
1236
1237
        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 - N` case the `text_1` sentence will be replicated `N`
        times to pair with the `text_2` sentences.
1238
        The input pairs are used to build a list of prompts for the
1239
1240
1241
1242
1243
1244
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
1245
                case it has to have the same length as the `text_2` list
1246
            text_2: The texts to pair with the query to form the input
1247
                to the LLM. See [PromptType][vllm.inputs.PromptType] for
1248
                more details about the format of each prompts.
1249
1250
1251
1252
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
1253
1254
1255
1256
1257
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
1258
            A list of `ScoringRequestOutput` objects containing the
1259
1260
            generated scores in the same order as the input prompts.
        """
1261
1262
1263
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]
1264

1265
1266
1267
            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
1268
                messages.append(
1269
1270
1271
1272
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")
1273
1274
1275

            raise ValueError(" ".join(messages))

1276
        if self.llm_engine.model_config.task not in ("embed", "score"):
1277
            raise ValueError(
1278
                "Score API is only enabled for `--task embed or --task score`")
1279
1280
1281
1282

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
1283
        tokenizer = self.get_tokenizer()
1284

1285
1286
1287
1288
        def ensure_str(prompt: SingletonPrompt):
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
1289
                                     "supported for scoring")
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            assert type(prompt) is str
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
1301
        input_text_1: list[str] = [ensure_str(t) for t in text_1]
1302
1303
1304
1305

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
1306
        input_text_2: list[str] = [ensure_str(t) for t in text_2]
1307

1308
        _validate_score_input_lens(input_text_1, input_text_2)
1309

1310
        if self.llm_engine.model_config.is_cross_encoder:
1311
1312
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
1313
1314
1315
1316
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
1317
1318
1319
1320
1321
1322
1323
1324
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1325

1326
1327
1328
1329
1330
1331
    def start_profile(self) -> None:
        """Start profiling by delegating to `self.llm_engine.start_profile()`."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to `self.llm_engine.stop_profile()`."""
        engine = self.llm_engine
        engine.stop_profile()

1332
1333
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Forward a prefix-cache reset request to the engine.

        Args:
            device: Passed through unchanged to
                `self.llm_engine.reset_prefix_cache`; defaults to `None`.

        Returns:
            Whatever the engine's `reset_prefix_cache` call returns.
        """
        was_reset = self.llm_engine.reset_prefix_cache(device)
        return was_reset
1334

1335
1336
1337
1338
1339
1340
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # The prefix cache is reset before the engine is told to sleep; the
        # kv cache content is forgotten at every sleep level (see above).
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

1357
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][] method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                `("weights", "kv_cache")`. If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1370

1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            The result of the engine's `get_metrics()` call, annotated
            here as a list of `Metric` objects.

        Note:
            This method is only available with the V1 LLM engine; an
            assertion fails when `self.llm_engine` is any other type.
        """
        # Imported lazily, inside the method, exactly as the original did.
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

        engine = self.llm_engine
        assert isinstance(engine, V1LLMEngine)
        return engine.get_metrics()

1385
1386
    # LEGACY
    def _convert_v1_inputs(
1387
        self,
1388
1389
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
1390
1391
    ):
        # skip_tokenizer_init is now checked in engine
1392

1393
1394
1395
1396
1397
1398
1399
1400
1401
        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")
        if prompts is not None and prompt_token_ids is not None \
                and len(prompts) != len(prompt_token_ids):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

1402
1403
1404
1405
1406
1407
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]
1408
1409
        if prompts is not None:
            num_requests = len(prompts)
1410
        elif prompt_token_ids is not None:
1411
            num_requests = len(prompt_token_ids)
1412
        parsed_prompts: list[PromptType] = []
1413
        for i in range(num_requests):
1414
            item: PromptType
1415

1416
            if prompts is not None:
1417
1418
1419
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
1420
            else:
1421
                raise AssertionError
1422

1423
            parsed_prompts.append(item)
1424

1425
        return parsed_prompts
1426
1427
1428

    def _validate_and_add_requests(
        self,
1429
        prompts: Union[PromptType, Sequence[PromptType]],
1430
1431
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
1432
        *,
1433
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
1434
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
1435
        prompt_adapter_request: Optional[PromptAdapterRequest],
1436
        tokenization_kwargs: Optional[dict[str, Any]] = None,
1437
        guided_options: Optional[GuidedDecodingRequest] = None,
1438
        priority: Optional[list[int]] = None,
1439
    ) -> None:
1440
1441
1442
1443
1444
1445
1446
1447
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

1448
        if isinstance(prompts, (str, dict)):
1449
            # Convert a single prompt to a list.
1450
            prompts = [prompts]
1451

1452
        num_requests = len(prompts)
1453
        if isinstance(params, Sequence) and len(params) != num_requests:
1454
            raise ValueError("The lengths of prompts and params "
1455
                             "must be the same.")
1456
        if isinstance(lora_request,
1457
                      Sequence) and len(lora_request) != num_requests:
1458
1459
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
1460

1461
        for sp in params if isinstance(params, Sequence) else (params, ):
1462
            if isinstance(sp, SamplingParams):
1463
                self._add_guided_params(sp, guided_options)
1464
1465
1466

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY
1467

Zhuohan Li's avatar
Zhuohan Li committed
1468
        # Add requests to the engine.
1469
1470
        it = prompts
        if use_tqdm:
1471
1472
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")
1473
1474

        for i, prompt in enumerate(it):
1475
            self._add_request(
1476
                prompt,
1477
                params[i] if isinstance(params, Sequence) else params,
1478
                tokenization_kwargs=tokenization_kwargs,
1479
1480
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
nunjunj's avatar
nunjunj committed
1481
                prompt_adapter_request=prompt_adapter_request,
1482
                priority=priority[i] if priority else 0,
nunjunj's avatar
nunjunj committed
1483
            )
1484

1485
    def _add_request(
nunjunj's avatar
nunjunj committed
1486
        self,
1487
        prompt: PromptType,
nunjunj's avatar
nunjunj committed
1488
        params: Union[SamplingParams, PoolingParams],
1489
        tokenization_kwargs: Optional[dict[str, Any]] = None,
1490
        lora_request: Optional[LoRARequest] = None,
nunjunj's avatar
nunjunj committed
1491
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
1492
        priority: int = 0,
1493
1494
    ) -> None:
        request_id = str(next(self.request_counter))
1495
1496
        self.llm_engine.add_request(
            request_id,
1497
            prompt,
1498
1499
            params,
            lora_request=lora_request,
1500
            tokenization_kwargs=tokenization_kwargs,
nunjunj's avatar
nunjunj committed
1501
            prompt_adapter_request=prompt_adapter_request,
1502
            priority=priority,
nunjunj's avatar
nunjunj committed
1503
        )
1504

1505
    def _add_guided_params(
1506
1507
1508
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
1509
1510
1511
1512
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
1513
            raise ValueError("Cannot set both guided_options_request and "
1514
1515
1516
1517
1518
1519
1520
1521
1522
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
1523
1524
1525
            whitespace_pattern=guided_options.guided_whitespace_pattern,
            structural_tag=guided_options.structural_tag,
        )
1526
1527
        return params

1528
    def _run_engine(
1529
1530
1531
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
1532
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
1533
1534
        # Initialize tqdm.
        if use_tqdm:
Zhuohan Li's avatar
Zhuohan Li committed
1535
            num_requests = self.llm_engine.get_num_unfinished_requests()
1536
1537
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
1538
1539
1540
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
1541
1542
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
1543
            )
1544

Zhuohan Li's avatar
Zhuohan Li committed
1545
        # Run the engine.
1546
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
1547
1548
        total_in_toks = 0
        total_out_toks = 0
Zhuohan Li's avatar
Zhuohan Li committed
1549
1550
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
1551
            for output in step_outputs:
1552
                if output.finished:
1553
1554
                    outputs.append(output)
                    if use_tqdm:
1555
1556
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
1557
                            n = len(output.outputs)
1558
                            assert output.prompt_token_ids is not None
1559
                            total_in_toks += len(output.prompt_token_ids) * n
1560
1561
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(
1562
                                len(stp.token_ids) for stp in output.outputs)
nunjunj's avatar
nunjunj committed
1563
1564
                            out_spd = (total_out_toks /
                                       pbar.format_dict["elapsed"])
1565
1566
1567
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
1568
                            pbar.update(n)
1569
1570
                        else:
                            pbar.update(1)
1571

1572
1573
        if use_tqdm:
            pbar.close()
1574
1575
1576
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
1577
        return sorted(outputs, key=lambda x: int(x.request_id))