llm.py 67.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Sequence
7
from contextlib import contextmanager
8
9
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
10

11
import cloudpickle
12
import torch.nn as nn
13
from tqdm.auto import tqdm
14
from typing_extensions import TypeVar, deprecated
15

16
17
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
18
19
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
20
21
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
22
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
23
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
24
                                         ChatTemplateContentFormatOption,
25
26
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
27
28
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
29
30
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
31
from vllm.entrypoints.utils import _validate_truncation_size
32
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
33
from vllm.inputs.parse import parse_and_batch_prompt
34
from vllm.logger import init_logger
35
from vllm.lora.request import LoRARequest
36
37
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
38
from vllm.model_executor.layers.quantization import QuantizationMethods
39
40
41
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
42
from vllm.pooling_params import PoolingParams
43
from vllm.prompt_adapter.request import PromptAdapterRequest
44
45
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
46
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
47
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
48
from vllm.usage.usage_lib import UsageContext
49
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
50

51
52
53
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

54
55
logger = init_logger(__name__)

56
57
_R = TypeVar("_R", default=Any)

58
59

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
60
61
62
63
64
65
66
67
68
69
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
70
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
71
72
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
73
74
75
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
76
77
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
78
79
80
81
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
82
83
84
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
85
86
87
88
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
89
        quantization: The method used to quantize the model weights. Currently,
90
            we support "awq", "gptq", and "fp8" (experimental).
91
92
93
94
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
95
96
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
97
98
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
99
100
101
102
103
104
105
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
106
107
108
109
110
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Noting that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
111
112
113
114
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
115
116
117
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
118
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
119
            When a sequence has context length larger than this, we fall back
120
121
122
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
123
124
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
125
126
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
127
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
130
131
132
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
133
134
135
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
136
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
137

138
139
    Note:
        This class is intended to be used for offline inference. For online
140
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
141
    """
142

143
    DEPRECATE_LEGACY: ClassVar[bool] = True
144
145
146
147
148
149
150
151
152
153
154
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

155
156
157
    def __init__(
        self,
        model: str,
        *,
        task: TaskOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Normalizes a few arguments, bundles everything into an `EngineArgs`,
        and builds the engine from it.
        """
        # Stat logging stays off unless the caller asks for it explicitly.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class passed as a type (rather than a qualified string
        # name) is serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if hf_overrides is None:
            hf_overrides = {}

        # Reduce every accepted form of `compilation_config` to an instance.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            # A bare integer is the optimization level.
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Keep only the keys that are init fields of CompilationConfig.
            init_fields = {
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            }
            compilation_config_instance = CompilationConfig(**init_fields)
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.llm_engine = engine
        self.engine_class = type(engine)

        self.request_counter = Counter()
        # Filled in on first use by `get_default_sampling_params`.
        self.default_sampling_params: Union[dict[str, Any], None] = None
250

251
252
253
254
255
256
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        return self.llm_engine.get_tokenizer_group().get_lora_tokenizer(
            lora_request)
257
258

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
259
        tokenizer_group = self.llm_engine.get_tokenizer_group()
260

261
262
263
264
        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        if tokenizer.__class__.__name__.startswith("Cached"):
265
            tokenizer_group.tokenizer = tokenizer
266
        else:
267
            tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
268

269
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the sampling parameters used when the caller gives none.

        The overrides come from the model config's
        `get_diff_sampling_param()`; they are fetched on first use and kept
        on `self.default_sampling_params`. A new `SamplingParams` object is
        built on every call.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides

        if not overrides:
            # No (or empty) overrides: plain defaults.
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

277
278
279
280
281
282
283
    # Primary (non-legacy) signature: prompts are positional-only and every
    # option after `sampling_params` is keyword-only. `priority` is declared
    # here because the implementation accepts it; without it, type checkers
    # reject `llm.generate(prompts, priority=[...])`.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

293
    # Legacy call shape: one text prompt, optionally accompanied by its
    # token ids through the deprecated `prompt_token_ids` parameter.
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # Legacy call shape: a list of text prompts, optionally accompanied by
    # one token-id list per prompt through `prompt_token_ids`.
    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # Legacy call shape: a single token-id list passed by keyword
    # (`prompt_token_ids` is required and keyword-only here); the text
    # prompt is optional.
    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # Legacy call shape: several token-id lists passed by keyword
    # (`prompt_token_ids` is required and keyword-only here); the list of
    # text prompts is optional.
    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # Legacy call shape: token ids supplied as the third positional
    # argument, with `None` in the `prompts` and `sampling_params` slots.
    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
374
375
376
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of supplying token ids. When given,
                it is combined with `prompts` into the new prompt format.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A
                dictionary may hold at most one entry and is converted to a
                `GuidedDecodingRequest`.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the "generate" or
                "transcription" runner, or if a `guided_options_request`
                dictionary has more than one entry.

        Note:
            Passing `prompt_token_ids` is considered legacy and may be
            deprecated in the future. You should instead put the token ids
            into the `prompts` parameter.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type not in ("generate", "transcription"):
            error_text = (
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).")

            # Tell the user how to recover when the model could have been
            # started with the right runner.
            if "generate" in model_config.supported_runner_types:
                error_text += " " + (
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(error_text)

        if prompt_token_ids is None:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)
        else:
            # Legacy path: merge text prompts and token ids.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        # Fall back to the default sampling params when none were given.
        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority,
        )

        engine_outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(engine_outputs,
                                                  RequestOutput)
477

478
    def collective_rpc(self,
479
                       method: Union[str, Callable[..., _R]],
480
                       timeout: Optional[float] = None,
481
482
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
483
484
485
486
487
488
489
490
491
492
493
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
494
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
495
496
497
498
499
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.
500

501
502
503
        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
504
        """
505
506

        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
507
508

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
509
        """
510
511
        Run a function directly on the model inside each worker,
        returning the result for each of them.
512
        """
513
514
        executor = self.llm_engine.model_executor
        return executor.apply_model(func)
515

516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt."""
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

532
533
    def beam_search(
        self,
534
        prompts: list[Union[TokensPrompt, TextPrompt]],
535
        params: BeamSearchParams,
536
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
537
    ) -> list[BeamSearchOutput]:
538
539
540
541
542
543
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
544
            params: The beam search parameters.
545
            lora_request: LoRA request to use for generation, if any.
546
        """
547
548
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
549
550
551
552
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
553
554
        length_penalty = params.length_penalty

555
556
557
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

558
559
560
561
        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)
562

563
564
565
566
567
568
569
570
571
572
573
574
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
575

576
577
578
579
580
581
        tokenizer = self.get_tokenizer()
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
582
                                            temperature=temperature)
583
        instances: list[BeamSearchInstance] = []
584

585
        for lora_req, prompt in zip(lora_requests, prompts):
586
587
588
589
590
591
592
593
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

594
595
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
596
597
598
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
599

600
            instances.append(
601
602
603
604
605
606
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
607
608

        for _ in range(max_tokens):
609
            all_beams: list[BeamSearchSequence] = list(
610
611
612
613
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
614
            instance_start_and_end: list[tuple[int, int]] = list(
615
616
617
618
619
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

620
621
622
623
            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])
624
625
626
627
628

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
629
630
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
647
                                logprobs=current_beam.logprobs + [logprobs],
648
                                lora_request=current_beam.lora_request,
649
                                cum_logprob=current_beam.cum_logprob +
650
651
652
653
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)
654
655
656
657
658
659
660

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
661
                                      key=sort_beams_key,
662
663
664
665
666
667
668
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
669
                                      key=sort_beams_key,
670
671
672
673
674
675
676
677
678
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
679
680
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Each conversation is rendered into token ids with the tokenizer's
        chat template and the resulting prompts are handed to the
        [generate][] method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation or a list of conversations.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. It is
                also used to select the tokenizer.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be
                used.
            chat_template_content_format: The format to render message
                content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are used when resolving
                the content format and are forwarded to the chat template.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. They are applied last, so they override the
                template arguments derived from the parameters above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        # Normalize the input into a batch of conversations: a list of lists
        # is already a batch, anything else is one conversation.
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list) else
            [cast(list[ChatCompletionMessageParam], messages)])

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Caller-supplied kwargs come last so they win on key clashes.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        is_mistral = isinstance(tokenizer, MistralTokenizer)
        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if is_mistral:
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            prompts.append(engine_prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

824
825
826
827
828
829
830
    # Preferred signature: prompts are positional-only, and everything
    # after `pooling_params` must be passed by keyword.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

839
    # LEGACY: single (prompt + optional token ids)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...
853

854
    # LEGACY: multi (prompt + optional token ids)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # LEGACY: single (token ids + optional prompt)
    # `prompt_token_ids` is required and keyword-only in this form.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # LEGACY: multi (token ids + optional prompt)
    # `prompt_token_ids` is required and keyword-only in this form.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # LEGACY: single or multi token ids [pos-only]
    # Covers calls that pass `None, None, token_ids` positionally.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
915
916
917
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Explicitly passed
                parameters are verified against the model config.
            prompt_token_ids: Legacy way of passing pre-tokenized prompts.
                When given, it is combined with `prompts` through
                `_convert_v1_inputs`.
            truncate_prompt_tokens: Optional truncation size. It is checked
                against the model's `max_model_len` by
                `_validate_truncation_size`, which fills in the tokenization
                kwargs sent along with the requests.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the "pooling"
                runner.

        Note:
            Using `prompt_token_ids` is considered legacy and may be
            deprecated in the future. You should instead pass token ids as
            part of the `prompts` parameter.
        """
        model_config = self.llm_engine.model_config

        current_runner = model_config.runner_type
        if current_runner != "pooling":
            reasons = ["LLM.encode() is only supported for pooling models."]

            # Give a more actionable hint when the model could have been
            # started as a pooling model.
            if "pooling" in model_config.supported_runner_types:
                reasons.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{current_runner}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(reasons))

        parsed_prompts: Union[PromptType, Sequence[PromptType]]
        if prompt_token_ids is None:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)
        else:
            # Legacy call style: merge text prompts and token ids.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        else:
            params_to_check = ([pooling_params] if isinstance(
                pooling_params, PoolingParams) else pooling_params)
            for single_params in params_to_check:
                single_params.verify(model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        raw_outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(raw_outputs,
                                                  PoolingRequestOutput)
1011

1012
1013
1014
1015
1016
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional truncation size, forwarded
                unchanged to `encode`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model task is not "embed".
        """
        task = self.llm_engine.model_config.task
        if task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled_outputs = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        # Re-wrap every generic pooling output as an embedding output.
        return [
            EmbeddingRequestOutput.from_base(pooled)
            for pooled in pooled_outputs
        ]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model task is not "classify".
        """
        task = self.llm_engine.model_config.task
        if task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        pooled_outputs = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        # Re-wrap every generic pooling output as a classification output.
        return [
            ClassificationRequestOutput.from_base(pooled)
            for pooled in pooled_outputs
        ]

1105
1106
1107
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by comparing their pooled embeddings.

        Both input lists are encoded in a single `encode` call and the
        results are split back into the two sides. If `text_1` produced
        exactly one output it is repeated to match the number of `text_2`
        outputs. The two sides are then passed to `_cosine_similarity`.

        Args:
            tokenizer: Tokenizer handed to `_cosine_similarity`.
            text_1: First side of the pairs.
            text_2: Second side of the pairs.
            truncate_prompt_tokens: Optional truncation size, forwarded to
                `encode`.
            use_tqdm: Progress-bar setting, forwarded to `encode`.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            A list of `ScoringRequestOutput` objects built from the
            validated similarity results.
        """
        split_at = len(text_1)

        pooled: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        left: list[PoolingRequestOutput] = pooled[:split_at]
        right: list[PoolingRequestOutput] = pooled[split_at:]

        # 1 -> N case: reuse the single left-hand output for every pair.
        if len(left) == 1:
            left = left * len(right)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=left,
                                    embed_2=right)

        validated = self.engine_class.validate_outputs(
            scores, PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(entry) for entry in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by feeding each pair to the model as one prompt.

        Every `(text_1[i], text_2[i])` pair is tokenized together through
        the tokenizer's `text` / `text_pair` arguments. The resulting
        token ids (and `token_type_ids`, when the tokenizer returns them)
        form one pooling request per pair.

        Args:
            tokenizer: Tokenizer used to encode the pairs. A
                `MistralTokenizer` is rejected.
            text_1: First side of the pairs. A single-element list is
                repeated to match the length of `text_2`.
            text_2: Second side of the pairs.
            truncate_prompt_tokens: Optional truncation size. It is checked
                against the model's `max_model_len` by
                `_validate_truncation_size`, which fills in the kwargs
                passed to the tokenizer.
            use_tqdm: Progress-bar setting used while adding requests and
                running the engine.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            A list of `ScoringRequestOutput` objects, one per input pair.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The previous message talked about `--task`, which is not what
            # this check is about; state the actual cause.
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # 1 -> N case: pair the single query with every candidate.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []
        for query, candidate in zip(text_1, text_2):
            prompt_inputs = tokenizer(text=query,
                                      text_pair=candidate,
                                      **tokenization_kwargs)
            parsed_prompts.append(
                TokensPrompt(
                    prompt_token_ids=prompt_inputs["input_ids"],
                    token_type_ids=prompt_inputs.get("token_type_ids")))

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1190
1191
1192
1193
1194
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 - N` case the `text_1` sentence will be replicated `N`
        times to pair with the `text_2` sentences.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the `text_2` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See [PromptType][vllm.inputs.PromptType] for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional truncation size, forwarded to
                the scoring helper that is selected for the model.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the "pooling"
                runner, if its task is neither "embed" nor "score", or if
                a prompt carries multi-modal data.
        """
        model_config = self.llm_engine.model_config

        current_runner = model_config.runner_type
        if current_runner != "pooling":
            reasons = ["LLM.score() is only supported for pooling models."]

            # Give a more actionable hint when the model could have been
            # started as a pooling model.
            if "pooling" in model_config.supported_runner_types:
                reasons.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{current_runner}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(reasons))

        if model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt):
            # Reduce any supported prompt form to plain text; token ids are
            # decoded back into a string.
            text = prompt
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                if "prompt_token_ids" in prompt:
                    text = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    text = cast(TextPrompt, prompt)["prompt"]
            assert type(text) is str
            return text

        # Convert a single prompt to a list.
        if isinstance(text_1, (str, dict)):
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(item) for item in text_1]

        # Convert a single prompt to a list.
        if isinstance(text_2, (str, dict)):
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(item) for item in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if not model_config.is_cross_encoder:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)

        return self._cross_encoding_score(tokenizer, input_text_1,
                                          input_text_2,
                                          truncate_prompt_tokens, use_tqdm,
                                          lora_request,
                                          prompt_adapter_request)
1293

1294
1295
1296
1297
1298
1299
    def start_profile(self) -> None:
        """Start profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

1300
1301
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Ask the underlying engine to reset its prefix cache.

        Args:
            device: Forwarded unchanged to the engine's
                `reset_prefix_cache`; `None` by default.

        Returns:
            Whatever the engine's `reset_prefix_cache` returns.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1302

1303
1304
1305
1306
1307
1308
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level.

                - Level 1 offloads the model weights and discards the kv
                  cache (its content is forgotten). The weights are backed
                  up in CPU memory, so make sure there is enough CPU memory
                  to hold them. Good for sleeping and waking up the engine
                  to run the same model again.
                - Level 2 discards both the model weights and the kv cache
                  (the content of both is forgotten). Good for sleeping and
                  waking up the engine to run a different model or to update
                  the model, where the previous weights are not needed. It
                  reduces CPU memory pressure.
        """
        # The prefix cache is reset first, then the engine is put to sleep.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1325
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][] method
        for more details.

        Args:
            tags: An optional list of tags selecting which memory
                allocations to reallocate. Values must be in
                `("weights", "kv_cache")`. If None, all memory is
                reallocated. wake_up should be called with all tags
                (or None) before the engine is used again.
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1338

1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            The list of ``Metric`` objects produced by the engine's own
            ``get_metrics()``, capturing the current state of all
            aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine; an
            assertion fails if the underlying engine is not a V1
            ``LLMEngine``.
        """
        # Imported locally, so the V1 engine module is only loaded when
        # metrics are actually requested.
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

        engine = self.llm_engine
        assert isinstance(engine, V1LLMEngine)
        return engine.get_metrics()

1353
1354
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert legacy `prompts` / `prompt_token_ids` arguments to prompts.

        Both arguments are normalised to batches with
        `parse_and_batch_prompt` before their sizes are compared, so a single
        text prompt paired with a single token-id list is treated as two
        batches of size one (rather than comparing the number of characters
        with the number of tokens).

        Args:
            prompts: A single text prompt or a list of text prompts.
            prompt_token_ids: A single list of token ids or a list of such
                lists.

        Returns:
            A list with one `TextPrompt` per entry of `prompts` when
            `prompts` is given; otherwise one `TokensPrompt` per entry of
            `prompt_token_ids`.

        Raises:
            ValueError: If both arguments are None, or if both are given and
                their batch sizes differ.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Normalise to batches first so the length check below compares
        # batch sizes, not a string length against a token count.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if prompts is not None and prompt_token_ids is not None \
                and len(prompts) != len(prompt_token_ids):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        parsed_prompts: list[PromptType]
        if prompts is not None:
            # Text takes precedence when both inputs are supplied.
            parsed_prompts = [TextPrompt(prompt=text) for text in prompts]
        elif prompt_token_ids is not None:
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=token_ids)
                for token_ids in prompt_token_ids
            ]
        else:
            # Unreachable: the both-None case is rejected above.
            raise AssertionError

        return parsed_prompts
1394
1395
1396

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate a batch of prompts and add one engine request per prompt.

        Args:
            prompts: A single prompt (a `str` or `dict`) or a sequence of
                prompts. A single prompt is wrapped in a one-element list.
            params: One params object shared by every prompt, or a sequence
                holding one params object per prompt.
            use_tqdm: If truthy, wrap the prompt iteration in a progress bar.
                A callable is used as the progress-bar factory in place of
                `tqdm`.
            lora_request: One LoRA request shared by every prompt, a sequence
                with one request per prompt, or None.
            prompt_adapter_request: Forwarded unchanged with every request.
            tokenization_kwargs: Forwarded unchanged with every request.
            guided_options: Deprecated. When given, a `DeprecationWarning` is
                emitted and the options are applied to every
                `SamplingParams` through `_add_guided_params`.
            priority: Optional per-prompt priorities; 0 is used for every
                prompt when this is None or empty.

        Raises:
            ValueError: If `params` or `lora_request` is a sequence whose
                length differs from the number of prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use the same `Sequence` test for validation, parameter fix-up and
        # per-prompt indexing, so that tuples (and other non-list sequences)
        # are validated and prepared exactly like lists.
        params_per_prompt = isinstance(params, Sequence)
        lora_per_prompt = isinstance(lora_request, Sequence)

        if params_per_prompt and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if lora_per_prompt and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if params_per_prompt else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if params_per_prompt else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i]
                if lora_per_prompt else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1452

1453
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Add a single request to the engine under a fresh request ID.

        The request ID is the next value of `self.request_counter`, converted
        to a string. All other arguments are forwarded unchanged to the
        engine's `add_request`.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1472

1473
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy legacy guided-decoding options onto `params`.

        Args:
            params: Sampling parameters to update in place.
            guided_options: Legacy guided-decoding request. When None,
                `params` is left untouched.

        Returns:
            The same `params` object that was passed in.

        Raises:
            ValueError: If `guided_options` is given while
                `params.guided_decoding` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            # Translate each legacy field to its GuidedDecodingParams
            # counterpart.
            params.guided_decoding = GuidedDecodingParams(
                json=guided_options.guided_json,
                regex=guided_options.guided_regex,
                choice=guided_options.guided_choice,
                grammar=guided_options.guided_grammar,
                json_object=guided_options.guided_json_object,
                backend=guided_options.guided_decoding_backend,
                whitespace_pattern=guided_options.guided_whitespace_pattern,
                structural_tag=guided_options.structural_tag,
            )

        return params

1496
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If truthy, show a progress bar. A callable is used as
                the progress-bar factory in place of `tqdm`.

        Returns:
            Every finished output, ordered by integer request ID.
        """
        # Set up the progress bar, sized to the currently unfinished requests.
        if use_tqdm:
            pending = self.llm_engine.get_num_unfinished_requests()
            make_bar = use_tqdm if callable(use_tqdm) else tqdm
            pbar = make_bar(
                total=pending,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Drive the engine, collecting outputs as they finish.
        finished: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            for output in self.llm_engine.step():
                if not output.finished:
                    continue

                finished.append(output)
                if not use_tqdm:
                    continue

                if not isinstance(output, RequestOutput):
                    pbar.update(1)
                    continue

                # Token throughput is only tracked for RequestOutput.
                n = len(output.outputs)
                assert output.prompt_token_ids is not None
                total_in_toks += len(output.prompt_token_ids) * n
                in_spd = total_in_toks / pbar.format_dict["elapsed"]
                total_out_toks += sum(
                    len(completion.token_ids)
                    for completion in output.outputs)
                out_spd = total_out_toks / pbar.format_dict["elapsed"]
                pbar.postfix = (f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                pbar.update(n)

        if use_tqdm:
            pbar.close()

        # Requests may finish out of order, so restore request-ID order
        # before returning.
        finished.sort(key=lambda out: int(out.request_id))
        return finished