llm.py 60 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
from vllm import envs
15
16
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
17
from vllm.config import CompilationConfig
18
19
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
20
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
21
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
22
                                         ChatTemplateContentFormatOption,
23
24
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
25
26
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
27
28
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
29
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
30
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
31
from vllm.logger import init_logger
32
from vllm.lora.request import LoRARequest
33
34
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
35
36
37
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
38
from vllm.pooling_params import PoolingParams
39
from vllm.prompt_adapter.request import PromptAdapterRequest
40
41
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
42
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
43
44
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
45
from vllm.usage.usage_lib import UsageContext
46
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
47

48
49
logger = init_logger(__name__)

# Result type of callables run on workers via `LLM.collective_rpc` /
# `LLM.apply_model`; falls back to `Any` when it cannot be inferred.
_R = TypeVar("_R", default=Any)


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allows API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_seq_len_to_capture: Maximum sequence length covered by CUDA
            graphs. When a sequence has context length larger than this, we
            fall back to eager mode. Additionally for encoder-decoder models,
            if the sequence length of the encoder input is larger than this,
            we fall back to the eager mode.
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        compilation_config: Either an integer, a dictionary, or an existing
            :class:`~vllm.config.CompilationConfig`. If it is an integer, it is
            used as the level of compilation optimization. If it is a
            dictionary, it can specify the full compilation configuration. A
            config object is passed through unchanged.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine-args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """

    # Read lazily (via a lambda) by the `deprecate_kwargs` decorator on
    # `generate`, so toggling it at runtime takes effect immediately.
    DEPRECATE_LEGACY: ClassVar[bool] = True
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    # Read lazily by the `deprecate_args` decorator on `__init__`.
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily enable deprecation warnings for the legacy API.

        Inside the ``with`` block, :attr:`DEPRECATE_LEGACY` is set to
        ``True``. On exit, including exit via an exception, the flag is
        restored to the value it had before the block was entered.

        Example::

            with LLM.deprecate_legacy_api():
                llm.generate(prompt_token_ids=...)
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore instead of hard-coding False: the class default is
            # True, so unconditionally clearing the flag would silently turn
            # off legacy-API warnings for the rest of the process.
            cls.DEPRECATE_LEGACY = previous

    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Collects the arguments (plus any extra ``**kwargs``) into an
        :class:`~vllm.EngineArgs` and builds the engine with
        ``engine_class.from_engine_args``. See the class docstring for the
        meaning of each argument.

        Note:
            If ``enforce_eager`` is unset (``None``), it defaults to False.
            ``disable_log_stats`` defaults to True unless given in ``kwargs``.
            A ``worker_cls`` passed as a class (rather than a qualified name
            string) is serialized with ``cloudpickle``.
        """
        kwargs.setdefault("disable_log_stats", True)

        # If the worker_cls is not a qualified string name, serialize it
        # with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if compilation_config is None:
            compilation_config_instance = None
        elif isinstance(compilation_config, (int, dict)):
            # Ints (optimization level) and dicts are converted from their
            # string form by `CompilationConfig.from_cli`.
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            # Already a config object: pass it through unchanged.
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Logic to switch between engines is done at runtime instead of import
        # to avoid import order issues
        self.engine_class = self.get_engine_class()
        self.llm_engine = self.engine_class.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

        self.request_counter = Counter()
        # Lazily populated by `get_default_sampling_params`.
        self.default_sampling_params: Optional[dict[str, Any]] = None

    @staticmethod
    def get_engine_class() -> type[LLMEngine]:
        """Return the engine class used to build ``self.llm_engine``.

        Returns the V1 engine when ``envs.VLLM_USE_V1`` is set, and the
        default :class:`LLMEngine` otherwise. The choice is made when this is
        called rather than at import time, to avoid import order issues.
        """
        if not envs.VLLM_USE_V1:
            return LLMEngine
        # Lazy import: the v1 package isn't distributed
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        return V1LLMEngine  # type: ignore

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Replace the tokenizer used by the engine's tokenizer group.

        The tokenizer is wrapped with :func:`get_cached_tokenizer` unless it
        already looks like a cached tokenizer.

        Args:
            tokenizer: The tokenizer to install.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # Cached tokenizer classes are created dynamically, so the only
        # available check is the class-name prefix. A user-defined tokenizer
        # whose class name starts with 'Cached' will be misjudged and left
        # unwrapped.
        if tokenizer.__class__.__name__.startswith("Cached"):
            tokenizer_group.tokenizer = tokenizer
        else:
            tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)

    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default sampling parameters for this model.

        The model-specific overrides are fetched once from
        ``model_config.get_diff_sampling_param()`` and stored in
        ``self.default_sampling_params``. If they are empty, a plain
        :class:`SamplingParams` is returned.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = (
                self.llm_engine.model_config.get_diff_sampling_param())
        if self.default_sampling_params:
            return SamplingParams.from_optional(**self.default_sampling_params)
        return SamplingParams()

    # Typing-only signatures for `generate`. The first one is the current API;
    # the remaining ones describe the legacy `prompt_token_ids` call forms.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        # Accepted by the implementation; without it here, type checkers
        # reject any call that passes `priority=`.
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters (see
                :meth:`get_default_sampling_params`).
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy. Token IDs of the prompt(s); when given,
                they are combined with ``prompts`` by ``_convert_v1_inputs``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                is converted into a :class:`GuidedDecodingRequest` and must
                contain at most one entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the ``generate``
                or ``transcription`` runner, or if ``guided_options_request``
                is a dict with more than one entry.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass prompts (including token IDs, e.g. as a
            :class:`~vllm.inputs.TokensPrompt`) positionally via the
            ``prompts`` parameter.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge text prompts and token IDs into prompt objects.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                :exc:`TimeoutError` on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        executor = self.llm_engine.model_executor
        return executor.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.

        Args:
            func: Callable that receives the worker's model module.

        Returns:
            A list with one result per worker, as returned by the model
            executor's ``apply_model``.
        """
        executor = self.llm_engine.model_executor
        return executor.apply_model(func)

    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is either a
                :class:`~vllm.inputs.TokensPrompt`, whose
                ``"prompt_token_ids"`` are used directly, or a
                :class:`~vllm.inputs.TextPrompt`, whose ``"prompt"`` text is
                encoded with this LLM's tokenizer.
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per prompt, in input order. Each
            holds up to ``params.beam_width`` sequences sorted by beam search
            score (best first), with ``text`` set to the decoded tokens.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Bound before `sort_beams_key` so the closure's dependency is visible.
        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten all live beams into one batch (linear time, unlike
            # `sum(lists, [])`). `pos` marks where each instance's slice of
            # that batch starts and ends.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Unfinished beams compete with completed ones for the final slots.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the :meth:`generate` method to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``
              - "auto" (the default) resolves to one of the above via
                ``resolve_chat_template_content_format``, based on the chat
                template and the tokenizer.

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional list of tool definitions. It is forwarded as-is
                to the chat template on both the HF and the Mistral path.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        # Resolve "auto" once, outside the loop; it only depends on the
        # template and the tokenizer, not on the individual conversation.
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            # The Mistral template consumes the raw messages, while the HF
            # template consumes the parsed conversation.
            prompt_data: Union[str, list[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            # The template may return token IDs or text; wrap accordingly.
            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

758
759
760
761
762
763
764
    # Preferred form: text or token IDs are carried by ``prompts`` itself
    # (e.g. as TextPrompt / TokensPrompt); ``prompts`` is positional-only.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
843
844
845
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            prompt_token_ids: LEGACY. Token IDs of the prompt(s). Prefer
                passing token IDs through ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the engine is not running the ``pooling`` runner.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and may be
            deprecated in the future. You should instead pass token IDs via
            the ``prompts`` parameter, e.g. as
            :class:`~vllm.inputs.TokensPrompt`.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            # Give an actionable hint when the model *could* pool but was
            # started with a different runner.
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy calling convention: build prompt dicts from the
            # separate text / token-id arguments.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
924

925
926
927
928
929
930
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with ``--task embed``.
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        # Re-wrap the generic pooling outputs as embedding-specific outputs.
        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification results (class logits) in the same order as the
            input prompts.

        Raises:
            ValueError: If the model was not initialized with
                ``--task classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        # Re-wrap the generic pooling outputs as classification outputs.
        return [ClassificationRequestOutput.from_base(item) for item in items]

1005
1006
1007
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their pooled embeddings.

        Both sides are embedded in a single :meth:`encode` call so they share
        one batch. The outputs are then split back at ``len(text_1)``. If
        ``text_1`` holds a single item, its embedding is paired with every
        item of ``text_2``.

        Note:
            ``truncate_prompt_tokens`` is accepted but currently NOT applied
            on this path: it is not forwarded to :meth:`encode`.
        """
        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        num_text_1 = len(text_1)
        encoded_output_1: list[PoolingRequestOutput] = (
            encoded_output[:num_text_1])
        encoded_output_2: list[PoolingRequestOutput] = (
            encoded_output[num_text_1:])

        # 1 -> N: reuse the single query embedding for every candidate.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score ``(text_1[i], text_2[i])`` pairs with a cross-encoder model.

        Each pair is tokenized jointly (``text`` / ``text_pair``) and sent to
        the engine as one :class:`TokensPrompt`, including
        ``token_type_ids`` when the tokenizer provides them. If ``text_1``
        holds a single item, it is paired with every item of ``text_2``.

        Args:
            truncate_prompt_tokens: If set, enables tokenizer truncation with
                this value as ``max_length``.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``, which
                cannot be used for pair tokenization here.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The actual problem is the tokenizer, not the --task flag, so
            # say so in the message.
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        # 1 -> N: replicate the single query so it pairs with every document.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts: list[TokensPrompt] = []
        for query, document in zip(text_1, text_2):
            prompt_inputs = tokenizer(text=query,
                                      text_pair=document,
                                      **tokenization_kwargs)
            parsed_prompts.append(
                TokensPrompt(
                    prompt_token_ids=prompt_inputs["input_ids"],
                    token_type_ids=prompt_inputs.get("token_type_ids")))

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1091
1092
1093
1094
1095
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        For cross-encoder models the pairs are tokenized together and scored
        by the model. For other models both sides are embedded and compared
        by cosine similarity. This method automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: If set, used as the tokenizer
                ``max_length`` (with truncation) for cross-encoder models.
                It is currently not applied for embedding models.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, is not running
                ``--task embed`` / ``--task score``, or a prompt contains
                multi-modal data.
            TypeError: If a prompt cannot be converted to text.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            """Reduce a singleton prompt to plain text (decoding token IDs)."""
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check instead of `assert`, which is stripped under -O.
            if not isinstance(prompt, str):
                raise TypeError("Unsupported prompt type for scoring: "
                                f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)

        return self._embedding_score(
            tokenizer,
            input_text_1,  # type: ignore[arg-type]
            input_text_2,  # type: ignore[arg-type]
            truncate_prompt_tokens,
            use_tqdm,
            lora_request,
            prompt_adapter_request)
1191

1192
1193
1194
1195
1196
1197
    def start_profile(self) -> None:
        """Tell the underlying engine to start profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Tell the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1198
1199
1200
    def reset_prefix_cache(self) -> bool:
        """Clear the engine's prefix cache.

        Returns:
            The value reported by the engine's ``reset_prefix_cache``,
            passed through unchanged.
        """
        engine = self.llm_engine
        succeeded = engine.reset_prefix_cache()
        return succeeded

1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
    def sleep(self, level: int = 1) -> None:
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level.

                - Level 1 sleep offloads the model weights and discards the
                  kv cache; the content of the kv cache is forgotten. The
                  model weights are backed up in CPU memory, so make sure
                  there is enough CPU memory to store them. Level 1 is good
                  for sleeping and waking up the engine to run the same
                  model again.
                - Level 2 sleep discards both the model weights and the kv
                  cache; the content of both is forgotten. Level 2 is good
                  for sleeping and waking up the engine to run a different
                  model or update the model, where previous model weights
                  are not needed. It reduces CPU memory pressure.
        """
        # Both sleep levels discard the kv cache, so clear the prefix cache
        # first rather than keep entries for content that will be gone.
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

    def wake_up(self) -> None:
        """
        Wake up the engine from sleep mode. See the :meth:`sleep` method
        for more details."""
        self.llm_engine.wake_up()

1228
1229
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ) -> list[PromptType]:
        """Convert the legacy ``prompts`` / ``prompt_token_ids`` arguments
        into a list of prompt dicts.

        Each argument may be a single item or a batch; both are normalized
        to lists with ``parse_and_batch_prompt``. If both are given, they
        must have the same length.

        Returns:
            One ``TextPrompt`` per request when ``prompts`` is given,
            otherwise one ``TokensPrompt`` per request.

        Raises:
            ValueError: If the two arguments differ in length, or if neither
                is provided.
        """
        # skip_tokenizer_init is now checked in engine
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        num_requests = None
        if prompts is not None:
            num_requests = len(prompts)
        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

            num_requests = len(prompt_token_ids)

        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        parsed_prompts: list[PromptType] = []
        for i in range(num_requests):
            item: PromptType

            # NOTE(review): when both arguments are given, only the text
            # prompt is kept and the token IDs serve just the length check
            # above; confirm this precedence is intended.
            if prompts is not None:
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
            else:
                # Unreachable: num_requests is set only if one is not None.
                raise AssertionError

            parsed_prompts.append(item)

        return parsed_prompts
1271
1272
1273

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and submit every prompt.

        ``params`` and ``lora_request`` may each be a single value shared by
        all prompts, or a sequence (list or tuple) with one entry per prompt.
        ``priority``, when given, must have one entry per prompt.
        ``SamplingParams`` are switched to ``RequestOutputKind.FINAL_ONLY``
        and receive the (deprecated) guided options.

        Raises:
            ValueError: If a per-request sequence does not have exactly one
                entry per prompt.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use the same `Sequence` test for validation as for indexing below.
        # Otherwise a tuple would skip the length check and be treated as a
        # single SamplingParams when setting output_kind.
        params_is_seq = isinstance(params, Sequence)
        lora_is_seq = isinstance(lora_request, Sequence)

        if params_is_seq and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if lora_is_seq and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        if priority is not None and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if params_is_seq else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if params_is_seq else params,
                lora_request=lora_request[i] if lora_is_seq else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1320

1321
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Hand a single prompt to the underlying engine.

        A new value is drawn from ``self.request_counter`` and converted to
        ``str`` to serve as the request ID. It is then passed to
        ``self.llm_engine.add_request`` together with the prompt, its params,
        and the optional LoRA / prompt-adapter requests and priority.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1338

1339
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Fold deprecated ``guided_options`` into ``params`` in place.

        When ``guided_options`` is ``None``, ``params`` is returned untouched.
        Otherwise a ``GuidedDecodingParams`` is built from the request's
        ``guided_*`` fields and assigned to ``params.guided_decoding``.

        Returns:
            The same ``params`` object that was passed in.

        Raises:
            ValueError: If ``params.guided_decoding`` is already set, since
                the two sources of guided-decoding settings would conflict.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")
            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
            )
        return params

1360
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If True, show a progress bar whose total is the number
                of unfinished requests at the start of this call. For each
                finished ``RequestOutput`` the bar's postfix is updated with
                the estimated input/output token throughput.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # Read the elapsed time once so both speeds share a
                        # denominator, and guard against a zero reading from
                        # a coarse clock (would raise ZeroDivisionError).
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                        else:
                            in_spd = out_spd = 0.0
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even if a step raises, so the terminal is not
            # left with a dangling progress bar.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were added before them.
        return sorted(outputs, key=lambda x: int(x.request_id))