"vllm/vscode:/vscode.git/clone" did not exist on "4f11b099a75eb737475d3fc18b48f2b526e5a5c4"
llm.py 60.8 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
from vllm.config import CompilationConfig
17
18
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
19
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
20
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
21
                                         ChatTemplateContentFormatOption,
22
23
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
24
25
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
26
27
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
28
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
29
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
30
from vllm.logger import init_logger
31
from vllm.lora.request import LoRARequest
32
33
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
34
35
36
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
37
from vllm.pooling_params import PoolingParams
38
from vllm.prompt_adapter.request import PromptAdapterRequest
39
40
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
41
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
42
43
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
44
from vllm.usage.usage_lib import UsageContext
45
46
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
47

48
49
logger = init_logger(__name__)

50
51
_R = TypeVar("_R", default=Any)

52
53

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
54
55
56
57
58
59
60
61
62
63
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
64
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
65
66
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
67
68
69
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
70
71
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
72
73
74
75
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
76
77
78
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
79
80
81
82
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
83
        quantization: The method used to quantize the model weights. Currently,
84
            we support "awq", "gptq", and "fp8" (experimental).
85
86
87
88
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
89
90
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
91
92
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
93
94
95
96
97
98
99
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
100
101
102
103
104
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Noting that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
105
106
107
108
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
109
110
111
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
112
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
113
            When a sequence has context length larger than this, we fall back
114
115
116
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
117
118
119
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
120
121
122
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
123
124
125
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
126
127
128
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
129
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
130
            :ref:`engine-args`)
nunjunj's avatar
nunjunj committed
131

132
133
134
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
135
    """
136

137
    DEPRECATE_LEGACY: ClassVar[bool] = True
138
139
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

140
141
142
143
144
145
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

146
147
148
149
150
151
152
153
154
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Context manager that enables the legacy-API deprecation flag.

        Sets :attr:`DEPRECATE_LEGACY` to ``True`` while the ``with`` body
        runs and to ``False`` once it exits. The reset is done in a
        ``finally`` clause so the flag is cleared even if the body raises.
        """
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Always clear the flag, including on error paths, so a failing
            # body cannot leave the class-wide flag switched on.
            cls.DEPRECATE_LEGACY = False

155
156
157
158
159
160
161
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """
        Build the engine arguments and create the underlying engine.

        Note: leaving ``enforce_eager`` unset (``None``) is treated as
        ``False``.
        """
        # Stats logging is disabled unless the caller passed
        # `disable_log_stats` explicitly.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class handed over as a class object (rather than a
        # qualified string name) is serialized with cloudpickle to avoid
        # pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # ints and dicts are routed through the CLI parser; any other
        # non-None value is forwarded unchanged.
        if compilation_config is None:
            compilation_config_instance = None
        elif isinstance(compilation_config, (int, dict)):
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            tokenizer_revision=tokenizer_revision,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.llm_engine = engine
        self.engine_class = type(engine)

        self.request_counter = Counter()
        # Filled in on first use by `get_default_sampling_params`.
        self.default_sampling_params: Union[dict[str, Any], None] = None
254

255
256
257
258
259
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's tokenizer group."""
        tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return tokenizer_group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer whose class name starts with ``"Cached"`` is stored
        as-is; any other tokenizer is first wrapped with
        ``get_cached_tokenizer``.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # The cached-tokenizer class is created dynamically, so the only
        # available check is on the class name. A user-defined tokenizer
        # whose class name begins with 'Cached' will be misjudged as
        # already cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))
268

269
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default sampling parameters for this model.

        On first use, the overrides are fetched from the model config via
        ``get_diff_sampling_param()`` and stored on the instance. If the
        overrides are empty, a plain ``SamplingParams()`` is returned.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides

        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

277
278
279
280
281
282
283
    # Typing-only overloads for `generate`. The first overload is the only
    # one not marked deprecated: prompts are positional-only and carry any
    # token IDs themselves. The remaining overloads describe the legacy
    # call shapes that pass `prompt_token_ids` separately.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
374
375
376
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing token IDs separately from
                ``prompts``. When given, ``prompts`` and ``prompt_token_ids``
                are converted together into the new prompt format.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A
                dictionary is converted to ``GuidedDecodingRequest`` and may
                contain at most one entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the "generate"
                or "transcription" runner, if neither ``prompts`` nor
                ``prompt_token_ids`` is provided, or if
                ``guided_options_request`` is a dictionary with more than
                one entry.

        Note:
            Passing ``prompt_token_ids`` separately is considered legacy and
            may be deprecated in the future. You should instead pass token
            IDs as part of the ``prompts`` parameter.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call shape: merge text prompts and token IDs.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        elif prompts is None:
            # Fail early with a clear message instead of forwarding `None`
            # as the prompt batch.
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
472

473
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Run an RPC call on every worker and collect the results.

        Args:
            method: Either the name of a worker method, or a callable that
                is serialized and shipped to all workers for execution.

                A callable must take an extra leading `self` argument (the
                worker object) ahead of whatever is supplied through `args`
                and `kwargs`.
            timeout: Maximum number of seconds to wait for execution; a
                :exc:`TimeoutError` is raised on timeout. `None` waits
                indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list with one result per worker.

        Note:
            This API is recommended for control messages only; set up
            data-plane communication for passing data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
502
503

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Apply ``func`` directly to the model inside each worker and
        return the per-worker results.
        """
        return self.llm_engine.model_executor.apply_model(func)
510

511
512
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is either a tokens
                prompt (its ``prompt_token_ids`` are used directly) or a
                text prompt (its ``prompt`` text is tokenized here).
            params: The beam search parameters.

        Returns:
            One ``BeamSearchOutput`` per input prompt, in input order, each
            holding at most ``beam_width`` sequences sorted by descending
            beam search score.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Bind the tokenizer (and its EOS id) before defining the sort key
        # that closes over them.
        tokenizer = self.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten the live beams of all instances into one batch
            # (linear time, unlike repeated list concatenation).
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(instance.beams
                                              for instance in instances))

            if not all_beams:
                break

            # pos[k]..pos[k+1] is the slice of `all_beams` that belongs to
            # instance k.
            pos = list(
                itertools.accumulate(
                    (len(instance.beams) for instance in instances),
                    initial=0))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still live when the loop ends compete with the
            # completed ones for the final ranking.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
621
622
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Every conversation is rendered into a prompt with the tokenizer's
        chat template and the resulting prompts are handed to
        :meth:`generate`.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional tool definitions. They are forwarded to the
                content-format resolution and to the chat template.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        # Normalise the input to a batch: a lone conversation (a list of
        # messages) becomes a batch of size one.
        conversations: list[list[ChatCompletionMessageParam]]
        if not is_list_of(messages, list):
            conversations = [cast(list[ChatCompletionMessageParam], messages)]
        else:
            conversations = cast(list[list[ChatCompletionMessageParam]],
                                 messages)

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        trust_remote_code = model_config.trust_remote_code
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            trust_remote_code=trust_remote_code,
        )
        uses_mistral_tokenizer = isinstance(tokenizer, MistralTokenizer)

        rendered_prompts: list[Union[TokensPrompt, TextPrompt]] = []
        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            # The template may yield either text or already-tokenized ids.
            rendered: Union[str, list[int]]
            if uses_mistral_tokenizer:
                rendered = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    chat_template=chat_template,
                    tools=tools,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer,
                    trust_remote_code=trust_remote_code,
                    conversation=conversation,
                    chat_template=chat_template,
                    tools=tools,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                )

            engine_prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(rendered, int):
                engine_prompt = TokensPrompt(prompt_token_ids=rendered)
            else:
                engine_prompt = TextPrompt(prompt=rendered)

            # Optional multi-modal payload and processor overrides.
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs

            rendered_prompts.append(engine_prompt)

        return self.generate(
            rendered_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

759
760
761
762
763
764
765
    # Typing-only overloads for ``encode``. The first is the current calling
    # convention; the remaining ones describe the legacy ``prompt_token_ids``
    # forms and are marked deprecated.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # ``prompts`` is positional-only; everything after ``pooling_params``
        # is keyword-only.
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        # Here ``prompt_token_ids`` is required and keyword-only.
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
844
845
846
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            prompt_token_ids: Legacy way of passing pre-tokenized prompts.
                When given, ``prompts`` and ``prompt_token_ids`` are converted
                together by the legacy input conversion.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the engine is not running the 'pooling' runner,
                or if neither ``prompts`` nor ``prompt_token_ids`` is given.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``prompts`` parameter.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge text prompts and token ids into prompt dicts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        elif prompts is None:
            # Fail early with a clear message instead of letting ``None``
            # reach the request validation, where it would surface as an
            # unrelated ``len(None)`` TypeError.
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
925

926
927
928
929
930
931
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This is a thin wrapper around :meth:`encode`: the prompts are pooled
        by ``encode`` and every pooled item is converted with
        ``EmbeddingRequestOutput.from_base``.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model's configured task is not ``"embed"``.
        """
        task = self.llm_engine.model_config.task
        if task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This is a thin wrapper around :meth:`encode`: the prompts are pooled
        by ``encode`` and every pooled item is converted with
        ``ClassificationRequestOutput.from_base``.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects, one per
            prompt, in the same order as the input prompts.

        Raises:
            ValueError: If the model's configured task is not ``"classify"``.
        """
        task = self.llm_engine.model_config.task
        if task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        pooled = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(ClassificationRequestOutput.from_base, pooled))

1006
1007
1008
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by comparing their pooled embeddings.

        Both lists are encoded in a single :meth:`encode` call, split back
        into their two halves and handed to ``_cosine_similarity``.

        Args:
            tokenizer: Tokenizer forwarded to ``_cosine_similarity``.
            text_1: First side of the pairs. A single entry is repeated to
                match the number of ``text_2`` entries.
            text_2: Second side of the pairs.
            truncate_prompt_tokens: Accepted for signature parity with the
                cross-encoder path; not used by this method.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request forwarded to :meth:`encode`.
            prompt_adapter_request: Prompt Adapter request forwarded to
                :meth:`encode`.

        Returns:
            One ``ScoringRequestOutput`` per pair.
        """
        split_at = len(text_1)
        pooled: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        left: list[PoolingRequestOutput] = pooled[:split_at]
        right: list[PoolingRequestOutput] = pooled[split_at:]

        # 1 -> N scoring: reuse the single left embedding for every right one.
        if len(left) == 1:
            left = left * len(right)

        scores: list[PoolingRequestOutput] = _cosine_similarity(
            tokenizer=tokenizer, embed_1=left, embed_2=right)

        checked = self.engine_class.validate_outputs(scores,
                                                     PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(entry) for entry in checked]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by running each pair through the model jointly.

        Every ``(text_1[i], text_2[i])`` pair is tokenized together via the
        tokenizer's ``text`` / ``text_pair`` arguments, submitted as a
        pooling request, and the pooled outputs are converted to
        ``ScoringRequestOutput``.

        Args:
            tokenizer: Tokenizer used to encode each pair.
            text_1: First side of the pairs. A single entry is repeated to
                match the number of ``text_2`` entries.
            text_2: Second side of the pairs.
            truncate_prompt_tokens: If given, tokenization is performed with
                ``truncation=True`` and this value as ``max_length``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to attach to the requests, if any.
            prompt_adapter_request: Prompt Adapter request to attach to the
                requests, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The previous message talked about `--task`, which has nothing
            # to do with why this branch rejects the call.
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # 1 -> N scoring: reuse the single query for every candidate.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []
        for query, candidate in zip(text_1, text_2):
            prompt_inputs = tokenizer(text=query,
                                      text_pair=candidate,
                                      **tokenization_kwargs)
            parsed_prompts.append(
                TokensPrompt(
                    prompt_token_ids=prompt_inputs["input_ids"],
                    token_type_ids=prompt_inputs.get("token_type_ids")))

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1092
1093
1094
1095
1096
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 - N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: Forwarded to the scoring helper. On the
                cross-encoder path it enables tokenizer truncation with this
                value as the maximum length.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            error = "LLM.score() is only supported for pooling models."
            if "pooling" in model_config.supported_runner_types:
                hint = (
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")
                error = " ".join((error, hint))
            raise ValueError(error)

        if model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt):
            # Reduce a prompt dict to plain text; strings pass through.
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                if "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            assert type(prompt) is str
            return prompt

        def as_str_list(texts) -> list[str]:
            # Convert a single prompt to a list, then stringify each entry.
            if isinstance(texts, (str, dict)):
                texts = [texts]
            return [ensure_str(entry) for entry in texts]

        input_text_1 = as_str_list(text_1)
        input_text_2 = as_str_list(text_2)

        _validate_score_input_lens(input_text_1, input_text_2)

        if not model_config.is_cross_encoder:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)

        return self._cross_encoding_score(tokenizer, input_text_1,
                                          input_text_2,
                                          truncate_prompt_tokens, use_tqdm,
                                          lora_request,
                                          prompt_adapter_request)
1192

1193
1194
1195
1196
1197
1198
    def start_profile(self) -> None:
        """Delegate to the engine's ``start_profile`` method."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Delegate to the engine's ``stop_profile`` method."""
        engine = self.llm_engine
        engine.stop_profile()

1199
1200
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Delegate to the engine's ``reset_prefix_cache`` and return its result.

        Args:
            device: Passed through to the engine unchanged.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1201

1202
1203
1204
1205
1206
1207
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset before the engine is put to sleep.

        Args:
            level: The sleep level.

                - Level 1 sleep will offload the model weights and discard
                  the kv cache. The content of kv cache is forgotten. It is
                  good for sleeping and waking up the engine to run the same
                  model again. The model weights are backed up in CPU
                  memory. Please make sure there's enough CPU memory to
                  store the model weights.
                - Level 2 sleep will discard both the model weights and the
                  kv cache. The content of both is forgotten. It is good for
                  sleeping and waking up the engine to run a different model
                  or update the model, where previous model weights are not
                  needed. It reduces CPU memory pressure.
        """
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1224
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the :meth:`sleep` method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                ("weights", "kv_cache",). If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1237

1238
1239
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert legacy ``prompts`` / ``prompt_token_ids`` arguments into a
        list of prompt dicts.

        Both inputs are first batched with ``parse_and_batch_prompt``. When
        text prompts are present the result is built from them as
        ``TextPrompt`` entries; otherwise ``TokensPrompt`` entries are built
        from the token ids.

        Raises:
            ValueError: If neither input is given, or if both are given with
                different batch sizes.
        """
        # skip_tokenizer_init is now checked in engine

        texts = None
        if prompts is not None:
            texts = [
                parsed["content"]
                for parsed in parse_and_batch_prompt(prompts)
            ]

        token_batches = None
        if prompt_token_ids is not None:
            token_batches = [
                parsed["content"]
                for parsed in parse_and_batch_prompt(prompt_token_ids)
            ]

        if (texts is not None and token_batches is not None
                and len(texts) != len(token_batches)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        if texts is None and token_batches is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        parsed_prompts: list[PromptType]
        if texts is not None:
            # Text takes precedence when both forms were supplied.
            parsed_prompts = [TextPrompt(prompt=text) for text in texts]
        elif token_batches is not None:
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=token_ids)
                for token_ids in token_batches
            ]
        else:
            raise AssertionError

        return parsed_prompts
1281
1282
1283

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and add one request per prompt.

        Args:
            prompts: A single prompt or a sequence of prompts.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt.
            lora_request: One LoRA request shared by every prompt, a sequence
                with exactly one entry per prompt, or None.
            prompt_adapter_request: Passed unchanged to every request.
            guided_options: Deprecated; emits a ``DeprecationWarning`` when
                given and is applied to each ``SamplingParams``.
            priority: Optional per-prompt priorities; 0 is used when it is
                None or empty.

        Raises:
            ValueError: If ``params`` or ``lora_request`` is a sequence whose
                length differs from the number of prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use the same ``Sequence`` test here as for the per-request indexing
        # below. Checking only for ``list`` let tuples skip validation and
        # the SamplingParams fix-up while still being indexed per request.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1330

1331
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Queue one prompt on the engine under a newly drawn request ID.

        The ID is the next value of ``self.request_counter``, converted to a
        string. All other arguments are forwarded to
        ``self.llm_engine.add_request`` unchanged.
        """
        next_id = next(self.request_counter)
        optional_args = dict(
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
        self.llm_engine.add_request(str(next_id), prompt, params,
                                    **optional_args)
1348

1349
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy legacy guided-decoding options onto ``params``.

        When ``guided_options`` is ``None`` nothing is changed. Otherwise
        ``params.guided_decoding`` is replaced with a ``GuidedDecodingParams``
        built from the fields of ``guided_options``.

        Returns:
            The same ``params`` object that was passed in.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
            )

        return params

1370
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until it has no unfinished requests.

        Args:
            use_tqdm: Whether to show a progress bar with estimated input
                and output token throughput.

        Returns:
            The finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        pbar: Optional[tqdm] = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                for output in self.llm_engine.step():
                    if not output.finished:
                        continue
                    outputs.append(output)

                    if pbar is None:
                        continue
                    if not isinstance(output, RequestOutput):
                        pbar.update(1)
                        continue

                    # Calculate tokens only for RequestOutput
                    n = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids) * n
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs)

                    # Read the elapsed time once, and guard against a zero
                    # reading so the speed estimate cannot raise
                    # ZeroDivisionError.
                    elapsed = pbar.format_dict["elapsed"]
                    in_spd = total_in_toks / elapsed if elapsed else 0.0
                    out_spd = total_out_toks / elapsed if elapsed else 0.0
                    pbar.postfix = (
                        f"est. speed input: {in_spd:.2f} toks/s, "
                        f"output: {out_spd:.2f} toks/s")
                    pbar.update(n)
        finally:
            # Close the bar even if a step raises.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))