llm.py 67.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Sequence
7
from contextlib import contextmanager
8
9
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
10

11
import cloudpickle
12
import torch.nn as nn
13
from tqdm.auto import tqdm
14
from typing_extensions import TypeVar, deprecated
15

16
17
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
18
19
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
20
21
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
22
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
23
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
24
                                         ChatTemplateContentFormatOption,
25
26
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
27
28
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
29
30
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
31
from vllm.entrypoints.utils import _validate_truncation_size
32
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
33
from vllm.inputs.parse import parse_and_batch_prompt
34
from vllm.logger import init_logger
35
from vllm.lora.request import LoRARequest
36
37
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
38
from vllm.model_executor.layers.quantization import QuantizationMethods
39
40
41
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
42
from vllm.pooling_params import PoolingParams
43
from vllm.prompt_adapter.request import PromptAdapterRequest
44
45
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
46
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
47
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
48
from vllm.usage.usage_lib import UsageContext
49
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
50

51
52
53
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

54
55
# Module-level logger for this file.
logger = init_logger(__name__)

# Per-worker result type returned (as a list) by `LLM.collective_rpc` and
# `LLM.apply_model`. `default=Any` is the PEP 696 TypeVar default, available
# through `typing_extensions.TypeVar`.
_R = TypeVar("_R", default=Any)

58
59

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
60
61
62
63
64
65
66
67
68
69
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
70
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
71
72
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
73
74
75
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
76
77
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
78
79
80
81
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
82
83
84
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
85
86
87
88
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
89
        quantization: The method used to quantize the model weights. Currently,
90
            we support "awq", "gptq", and "fp8" (experimental).
91
92
93
94
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
95
96
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
97
98
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
99
100
101
102
103
104
105
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
106
107
108
109
110
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, values that are too small may cause out-of-memory (OOM)
            errors. Note that `best_of` is only supported in V0.
111
112
113
114
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
115
116
117
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
118
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
119
            When a sequence has context length larger than this, we fall back
120
121
122
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
123
124
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
125
126
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
127
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
130
131
132
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
133
134
135
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
136
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
137

138
139
    Note:
        This class is intended to be used for offline inference. For online
140
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
141
    """
142

143
    DEPRECATE_LEGACY: ClassVar[bool] = True
144
145
146
147
148
149
150
151
152
153
154
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

155
156
157
    def __init__(
        self,
        model: str,
        *,
        task: TaskOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Builds an `EngineArgs` from the arguments (see the class docstring
        for their meaning) and creates the underlying engine via
        `LLMEngine.from_engine_args`. `task`, `mm_processor_kwargs` and
        `override_pooler_config` are forwarded to `EngineArgs` unchanged.

        Besides forwarding, the constructor applies these adjustments:

        * `disable_log_stats` defaults to `True` unless given in `kwargs`.
        * A `worker_cls` given as a class (not a qualified name) is
          serialized with `cloudpickle`.
        * `hf_overrides=None` becomes an empty dict.
        * `compilation_config` may be `None`, an int (used as the
          compilation level) or a dict of `CompilationConfig` fields. Dict
          keys that are not init fields of `CompilationConfig` are ignored,
          and a warning lists them. Any other value is passed through as-is.
        """
        # Offline usage is quiet by default; callers can still opt in.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class object (rather than a qualified string name) is
        # serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if hf_overrides is None:
            hf_overrides = {}

        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            init_kwargs = {
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            }
            ignored_keys = [
                key for key in compilation_config if key not in init_kwargs
            ]
            if ignored_keys:
                # Previously these were dropped silently, which hid typos.
                logger.warning(
                    "Ignoring compilation_config keys that are not init "
                    "fields of CompilationConfig: %s", ignored_keys)
            compilation_config_instance = CompilationConfig(**init_kwargs)
        else:
            # Presumably already a CompilationConfig instance.
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily filled by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer used by the engine.

        Delegates to the engine tokenizer group's `get_lora_tokenizer`,
        passing `lora_request` through (it may be `None`).
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the engine's tokenizer group.

        A tokenizer whose class name starts with "Cached" is installed as-is.
        Any other tokenizer is first wrapped with `get_cached_tokenizer`.
        """
        group = self.llm_engine.get_tokenizer_group()

        # The cached-tokenizer class is created dynamically, so there is no
        # class to isinstance() against; match on the class-name prefix
        # instead. A user-defined tokenizer whose class name happens to start
        # with "Cached" will be misjudged as already cached.
        is_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer
                           if is_cached else get_cached_tokenizer(tokenizer))
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default `SamplingParams` for this model.

        On first use, the overrides are fetched from the model config's
        `get_diff_sampling_param()` and cached on the instance. If the cached
        overrides are empty, plain `SamplingParams()` is returned.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)
    # Preferred signature: prompts (text, token or structured prompts) are
    # passed positionally; everything after `sampling_params` is keyword-only.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        # Accepted by the implementation; declared here so type checkers
        # allow `generate(prompts, priority=[...])`.
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...
    # Legacy signatures: token IDs are passed through `prompt_token_ids`
    # separately from `prompts`. They are kept for backward compatibility and
    # marked with `@deprecated` so static type checkers flag their use.

    @overload  # LEGACY: one text prompt, token ids optional
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]: ...

    @overload  # LEGACY: many text prompts, token ids optional
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]: ...

    @overload  # LEGACY: one token-id prompt (keyword), text optional
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]: ...

    @overload  # LEGACY: many token-id prompts (keyword), text optional
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]: ...

    @overload  # LEGACY: token ids passed positionally, prompts/params None
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]: ...
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing token IDs separately from
                `prompts`. While `LLM.DEPRECATE_LEGACY` is set, using it is
                reported as deprecated; include the token IDs in `prompts`
                instead (e.g. as a `TokensPrompt`).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A plain
                dict may specify only one guided decoding option; it is
                converted into a `GuidedDecodingRequest`.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the "generate"
                or "transcription" runner, or if `guided_options_request` is
                a dict specifying more than one guided decoding option.

        Note:
            Passing `prompts` as a keyword argument, or passing token IDs via
            `prompt_token_ids`, is considered legacy and may be removed in the
            future. Instead, pass the prompts positionally and include any
            token IDs in the prompts themselves (e.g. as a `TokensPrompt`).
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            if "generate" in model_config.supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge separately-passed token IDs into prompts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
    def collective_rpc(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> list[_R]:
        """
        Invoke `method` on every worker and collect the per-worker results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to every worker for execution.

                A callable receives the worker object as an extra leading
                `self` argument, followed by `args` and `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait. A
                [`TimeoutError`][] is raised when it is exceeded; `None`
                waits indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, collected in a list.

        Note:
            Prefer using this API for control messages only; set up a
            dedicated data-plane channel for moving data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` directly on the model held by each worker.

        Delegates to the model executor's `apply_model` and returns its
        per-worker results.
        """
        return self.llm_engine.model_executor.apply_model(func)
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional LoRA request corresponding to each prompt.

        Args:
            lora_request: `None`, a single request applied to every prompt,
                or a sequence holding exactly one request per prompt.
            prompts: The prompts the requests are paired with.

        Returns:
            A list with one (possibly `None`) LoRA request per prompt.

        Raises:
            ValueError: If a sequence is given whose length differs from
                `len(prompts)`.
            TypeError: If `lora_request` has any other type.
        """
        # Strings/bytes are Sequences too, but never valid request lists.
        if isinstance(lora_request, Sequence) and not isinstance(
                lora_request, (str, bytes)):
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts")
            # Previously a correctly-sized list fell through to the TypeError
            # below; return a copy so each prompt gets its own request.
            return list(lora_request)

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")
    def beam_search(
        self,
533
        prompts: list[Union[TokensPrompt, TextPrompt]],
534
        params: BeamSearchParams,
535
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
536
    ) -> list[BeamSearchOutput]:
537
538
539
540
541
542
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
543
            params: The beam search parameters.
544
            lora_request: LoRA request to use for generation, if any.
545
        """
546
547
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
548
549
550
551
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
552
553
        length_penalty = params.length_penalty

554
555
556
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

557
558
559
560
        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)
561

562
563
564
565
566
567
568
569
570
571
572
573
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
574

575
576
577
578
579
580
        tokenizer = self.get_tokenizer()
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
581
                                            temperature=temperature)
582
        instances: list[BeamSearchInstance] = []
583

584
        for lora_req, prompt in zip(lora_requests, prompts):
585
586
587
588
589
590
591
592
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

593
594
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
595
596
597
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
598

599
            instances.append(
600
601
602
603
604
605
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
606
607

        for _ in range(max_tokens):
608
            all_beams: list[BeamSearchSequence] = list(
609
610
611
612
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
613
            instance_start_and_end: list[tuple[int, int]] = list(
614
615
616
617
618
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

619
620
621
622
            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])
623
624
625
626
627

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
628
629
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
646
                                logprobs=current_beam.logprobs + [logprobs],
647
                                lora_request=current_beam.lora_request,
648
                                cum_logprob=current_beam.cum_logprob +
649
650
651
652
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)
653
654
655
656
657
658
659

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
660
                                      key=sort_beams_key,
661
662
663
664
665
666
667
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
668
                                      key=sort_beams_key,
669
670
671
672
673
674
675
676
677
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
678
679
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Every conversation is rendered through the chat template into prompt
        token IDs, and the resulting prompts are handed to the [generate][]
        method to produce the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation or a list of conversations.

                - A conversation is a list of messages.
                - A message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, the default sampling parameters are used. A single
                value applies to every prompt; a list must have the same
                length as the prompts and is paired with them one by one.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. It is
                also passed to `get_tokenizer` when selecting the tokenizer
                that renders the conversations.
            chat_template: The template used to structure the chat.
                If not provided, the model's default chat template is used.
            chat_template_content_format: The format used to render message
                content.

                - "string" renders the content as a string.
                  Example: `"Who are you?"`
                - "openai" renders the content as a list of dictionaries,
                  similar to the OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are used when resolving
                the content format and are forwarded to the chat template.
            chat_template_kwargs: Extra kwargs for the chat template. They
                are merged last, so they override the explicit template
                arguments above on key collisions.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request, attached to every rendered prompt. Only used for
                offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        # Normalise the input into a batch of conversations.
        if is_list_of(messages, list):
            conversations = cast(list[list[ChatCompletionMessageParam]],
                                 messages)
        else:
            conversations = [cast(list[ChatCompletionMessageParam], messages)]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Explicit arguments first; user-supplied chat_template_kwargs are
        # layered on top and win on key collisions.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
        }
        if chat_template_kwargs:
            template_kwargs.update(chat_template_kwargs)

        uses_mistral = isinstance(tokenizer, MistralTokenizer)
        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if uses_mistral:
                # Mistral tokenizers render the raw messages straight to IDs.
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already emits any special tokens, so the
                # tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs
            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

823
824
825
826
827
828
829
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Preferred form: prompts (plus optional pooling params) only."""

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: one text prompt with optional token IDs."""

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: several text prompts with optional token IDs."""

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: one token-ID list given by keyword."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: several token-ID lists given by keyword."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: token IDs passed positionally after two `None`s."""

nunjunj's avatar
nunjunj committed
914
915
916
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Explicitly provided
                parameters (single or per-prompt) are verified against the
                model config before any request is queued.
            prompt_token_ids: Legacy token-ID input, combined with `prompts`
                via the v1 input conversion. Deprecated; pass token prompts
                through `prompts` instead.
            truncate_prompt_tokens: Optional truncation length for each
                prompt. It is validated against the model's `max_model_len`
                and forwarded to tokenization.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the engine is not running the pooling runner.

        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `prompts` parameter.
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            # Give a more actionable hint when the model could have been
            # started as a pooling model but was not.
            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge text prompts and token IDs into prompt dicts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(model_config)

        # Filled in by the validator with tokenizer truncation settings.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
1010

1011
1012
1013
1014
1015
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        Prompts are batched automatically with respect to the memory
        constraint, so for the best throughput pass all of your prompts in a
        single call.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional per-prompt truncation length,
                forwarded unchanged to `encode`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not started with `--task embed`.
        """
        configured_task = self.llm_engine.model_config.task
        if configured_task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        truncate_prompt_tokens: Optional[int] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class scores for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            truncate_prompt_tokens: Optional per-prompt truncation length,
                forwarded to `encode` (same option as `embed` offers).
            pooling_params: The pooling parameters for pooling. If None, the
                default pooling parameters are used; otherwise they are
                forwarded to `encode`, which verifies them.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model was not started with `--task classify`.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1104
1105
1106
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by cosine similarity of their pooled embeddings.

        Both sides are encoded in a single `encode` call (all of `text_1`
        followed by all of `text_2`). The result is then split back into the
        two halves. A lone `text_1` embedding is repeated so that it pairs
        with every `text_2` embedding.

        Returns:
            One `ScoringRequestOutput` per pair, in `text_2` order.
        """
        pooled: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        split_at = len(text_1)
        query_embeds: list[PoolingRequestOutput] = pooled[0:split_at]
        doc_embeds: list[PoolingRequestOutput] = pooled[split_at:]

        # 1 -> N: broadcast the single query embedding across all documents.
        if len(query_embeds) == 1:
            query_embeds = query_embeds * len(doc_embeds)

        similarity = _cosine_similarity(tokenizer=tokenizer,
                                        embed_1=query_embeds,
                                        embed_2=doc_embeds)

        validated = self.engine_class.validate_outputs(similarity,
                                                       PoolingRequestOutput)
        return list(map(ScoringRequestOutput.from_base, validated))

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score `(text_1[i], text_2[i])` pairs with a cross-encoder model.

        Each pair is tokenized jointly (as `text` / `text_pair`). The
        resulting token IDs, plus token type IDs when the tokenizer returns
        them, are run through the engine with default pooling parameters.

        Args:
            tokenizer: Tokenizer used to encode each text pair.
            text_1: Query texts. A single-element list is repeated so that it
                pairs with every entry of `text_2`.
            text_2: Texts paired with the queries.
            truncate_prompt_tokens: Optional truncation length, validated
                against `max_model_len`. It produces the tokenizer kwargs
                applied to every pair.
            use_tqdm: Progress-bar setting forwarded to the engine helpers.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One `ScoringRequestOutput` per pair, in pair order.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`. That tokenizer
                type is rejected for cross-encoder scoring.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The previous message blamed the `--task` flag, which was not
            # the cause; the tokenizer type is what is unsupported here.
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        # 1 -> N: broadcast the single query across all paired texts.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []
        for query, passage in input_pairs:
            prompt_inputs = tokenizer(text=query,
                                      text_pair=passage,
                                      **tokenization_kwargs)
            parsed_prompts.append(
                TokensPrompt(
                    prompt_token_ids=prompt_inputs["input_ids"],
                    token_type_ids=prompt_inputs.get("token_type_ids")))

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1189
1190
1191
1192
1193
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `text_1` sentence will be replicated `N`
        times to pair with the `text_2` sentences.
        Cross-encoder models tokenize each pair jointly. Other models embed
        both sides and compare them by cosine similarity. The prompts are
        batched automatically, considering the memory constraint. For the
        best performance, put all of your texts into a single list and pass
        it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the `text_2` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See [PromptType][vllm.inputs.PromptType] for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional per-prompt truncation length,
                forwarded to the scoring helper.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running the pooling runner, was
                not started with `--task embed` or `--task score`, or a
                prompt is multi-modal or cannot be turned into text.
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            # Normalise any supported singleton prompt into plain text.
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # An explicit check instead of `assert`, which is stripped under
            # `python -O` and would let non-text prompts through silently.
            if not isinstance(prompt, str):
                raise ValueError(
                    f"Unsupported prompt for scoring: {type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1292

1293
1294
1295
1296
1297
1298
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling.

        This is a thin pass-through to ``self.llm_engine.start_profile()``;
        whatever the engine returns is discarded.
        """
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling.

        This is a thin pass-through to ``self.llm_engine.stop_profile()``;
        whatever the engine returns is discarded.
        """
        engine = self.llm_engine
        engine.stop_profile()

1299
1300
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, forwarded unchanged to the
                engine (``None`` is passed through as-is).

        Returns:
            The boolean result reported by the engine's own
            ``reset_prefix_cache``.
        """
        engine = self.llm_engine
        result = engine.reset_prefix_cache(device)
        return result
1301

1302
1303
1304
1305
1306
1307
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset before the engine is put to sleep.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # Both sleep levels discard the KV-cache contents (see `level`
        # above), so previously cached prefixes are presumably no longer
        # valid after waking; drop them first.
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

1324
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][] method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                `("weights", "kv_cache")`. If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        # `tags` is forwarded unchanged; validation is left to the engine.
        self.llm_engine.wake_up(tags)
1337

1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects capturing the current state of all
            aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine.
        """
        # Imported inside the method, so the V1 engine module is only loaded
        # when metrics are actually requested.
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        assert isinstance(self.llm_engine, V1LLMEngine), (
            "LLM.get_metrics() is only supported with the V1 LLM engine.")
        return self.llm_engine.get_metrics()

1352
1353
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert legacy ``prompts`` / ``prompt_token_ids`` arguments into
        a list of prompt dicts.

        Args:
            prompts: A single text prompt or a list of text prompts.
            prompt_token_ids: A single token-id list or a list of token-id
                lists.

        Returns:
            One ``TextPrompt`` per text prompt when ``prompts`` is given
            (``prompt_token_ids`` is then only used for the length check);
            otherwise one ``TokensPrompt`` per token-id list.

        Raises:
            ValueError: If neither argument is given, or if both are given
                but describe a different number of requests.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Batch single inputs *before* comparing lengths; otherwise a single
        # string's character count would be compared against a single
        # prompt's token count.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if prompts is not None and prompt_token_ids is not None \
                and len(prompts) != len(prompt_token_ids):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        parsed_prompts: list[PromptType]
        if prompts is not None:
            # Text prompts take precedence when both inputs are supplied.
            parsed_prompts = [TextPrompt(prompt=p) for p in prompts]
        else:
            # Guaranteed by the "neither provided" check above.
            assert prompt_token_ids is not None
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=ids) for ids in prompt_token_ids
            ]

        return parsed_prompts
1393
1394
1395

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and submit every prompt to the
        engine.

        Args:
            prompts: A single prompt (``str`` or prompt dict) or a sequence
                of prompts.
            params: Either one params object shared by all prompts, or a
                sequence with exactly one entry per prompt.
            use_tqdm: If truthy, show a progress bar while adding requests;
                a callable is used as the tqdm factory.
            lora_request: Either one LoRA request shared by all prompts, or
                a sequence with exactly one entry per prompt.
            prompt_adapter_request: Prompt adapter passed with every request.
            tokenization_kwargs: Extra tokenization keyword arguments passed
                with every request.
            guided_options: Deprecated; merged into each ``SamplingParams``
                via ``_add_guided_params``.
            priority: Optional per-request priorities, one per prompt. When
                omitted (or empty) every request gets priority 0.

        Raises:
            ValueError: If a per-request sequence (``params``,
                ``lora_request`` or ``priority``) does not have exactly one
                entry per prompt.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # A Sequence means "one entry per prompt". Use the same Sequence
        # check as the indexing below, so tuples are validated (and have
        # their SamplingParams adjusted) exactly like lists.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Validate up front so a short list cannot raise IndexError after
        # some requests have already been submitted to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1451

1452
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine under a fresh request ID.

        The ID is ``str(next(self.request_counter))``. ``_run_engine`` sorts
        finished outputs by ``int(request_id)``, so the counter is presumably
        expected to yield increasing integers.
        """
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1471

1472
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy legacy ``guided_options`` onto ``params.guided_decoding``.

        ``params`` is modified in place and also returned.

        Args:
            params: Sampling parameters to update.
            guided_options: Deprecated guided-decoding request. When ``None``,
                ``params`` is returned untouched.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern,
            structural_tag=guided_options.structural_tag,
        )
        return params

1495
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until every pending request has finished.

        Args:
            use_tqdm: If truthy, show a progress bar over the pending
                requests; a callable is used as the tqdm factory.

        Returns:
            The finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # `elapsed` can be 0.0 if an output arrives within
                        # the timer's resolution; avoid dividing by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        in_spd = total_in_toks / elapsed if elapsed else 0.0
                        out_spd = (total_out_toks /
                                   elapsed if elapsed else 0.0)
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            # Close the bar even if a step raises, so it is not left open.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))