llm.py 62.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
from vllm.config import CompilationConfig
17
18
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
19
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
20
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
21
                                         ChatTemplateContentFormatOption,
22
23
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
24
25
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
26
27
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
28
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
29
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
30
from vllm.logger import init_logger
31
from vllm.lora.request import LoRARequest
32
33
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
34
35
36
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
37
from vllm.pooling_params import PoolingParams
38
from vllm.prompt_adapter.request import PromptAdapterRequest
39
40
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
41
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
42
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
43
from vllm.usage.usage_lib import UsageContext
44
45
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
46

47
48
logger = init_logger(__name__)

# Generic result type of callables executed on the workers; used as the
# element type of the lists returned by `LLM.collective_rpc` and
# `LLM.apply_model`.
_R = TypeVar("_R", default=Any)

51
52

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
53
54
55
56
57
58
59
60
61
62
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
63
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
64
65
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
66
67
68
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
69
70
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
71
72
73
74
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk, so it should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
75
76
77
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
78
79
80
81
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
82
        quantization: The method used to quantize the model weights. Currently,
83
            we support "awq", "gptq", and "fp8" (experimental).
84
85
86
87
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
88
89
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
90
91
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
92
93
94
95
96
97
98
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
99
100
101
102
103
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
104
105
106
107
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
108
109
110
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
111
        max_seq_len_to_capture: Maximum sequence length covered by CUDA graphs.
112
            When a sequence has context length larger than this, we fall back
113
114
115
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
116
117
118
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
119
120
121
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
122
123
124
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
125
126
127
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
128
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
129
            :ref:`engine-args`)
nunjunj's avatar
nunjunj committed
130

131
132
133
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
134
    """
135

136
    DEPRECATE_LEGACY: ClassVar[bool] = True
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily enable deprecation warnings for the legacy API.

        Sets ``DEPRECATE_LEGACY`` to ``True`` for the duration of the
        ``with`` block, then restores whatever value it had on entry, even
        if the block raises.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore instead of hard-coding False: the class default is
            # True, so resetting to False would silently turn the legacy-API
            # deprecation warnings off for the rest of the process. The
            # `finally` also guarantees restoration when the body raises.
            cls.DEPRECATE_LEGACY = previous

154
155
156
157
158
159
160
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """
        LLM constructor.

        Collects the explicit parameters plus ``**kwargs`` into an
        :class:`~vllm.EngineArgs`, builds the engine from it, and initializes
        the per-instance request counter and the lazily filled default
        sampling-parameter cache.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.
        """
        # Offline use: stats logging is off unless the caller asks for it.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class given as a class object (rather than a qualified
        # name string) is serialized with cloudpickle to avoid pickling
        # issues when it is shipped to the workers.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # An int (optimization level) or dict (full config) is parsed via the
        # CLI parser; None or an already-built config object passes through.
        compilation_config_instance = compilation_config
        if isinstance(compilation_config, (int, dict)):
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled on first call to get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None
253

254
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's tokenizer group."""
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.tokenizer
256
257

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer whose class name starts with ``"Cached"`` is installed
        as-is; any other tokenizer is wrapped with
        :func:`get_cached_tokenizer` first.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()

        # The cached tokenizer class is dynamic, so comparing the class name
        # is the only option. A user-defined tokenizer whose class name
        # happens to start with 'Cached' will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        tokenizer_group.tokenizer = (tokenizer if looks_cached else
                                     get_cached_tokenizer(tokenizer))
267

268
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default :class:`SamplingParams` for this model.

        The model config's diff sampling params are fetched on the first
        call and cached on the instance. If they are empty, a plain
        ``SamplingParams()`` is returned; otherwise they are applied via
        ``SamplingParams.from_optional``.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

276
277
278
279
280
281
282
    # Static call signatures for `generate`. Only the first overload is the
    # current API; the others cover the deprecated `prompt_token_ids` forms.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        # Mirrors the implementation's `priority` parameter; without it,
        # type checkers reject `generate(..., priority=[...])`.
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
373
374
375
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Deprecated. Token IDs of the prompt(s); pass
                token IDs through ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options for the
                requests, if any. A plain dict is converted into a
                :class:`GuidedDecodingRequest` and may contain at most one
                entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the "generate"
                or "transcription" runner, or if ``guided_options_request``
                is a dict with more than one entry.

        Note:
            Passing token IDs through the ``prompt_token_ids`` keyword is
            legacy and deprecated. Pass them via the ``prompts`` parameter
            instead (e.g. as :class:`~vllm.inputs.TokensPrompt` objects).
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            # If the model could run as a generator, point the user at the
            # task flag that enables it.
            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge separate prompts/token ids into prompt dicts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
471

472
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Run ``method`` on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to each worker. A callable receives
                the worker object as an extra leading `self` argument, followed
                by the values from `args` and `kwargs`.
            timeout: Upper bound in seconds to wait for the call. Raises a
                :exc:`TimeoutError` when exceeded; `None` waits indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, collected into a list.

        Note:
            Prefer using this API for control messages only, and set up a
            separate data-plane communication path for passing data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
501
502

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call ``func`` on the model held by each worker and return the
        per-worker results, as produced by the engine's model executor.
        """
        return self.llm_engine.model_executor.apply_model(func)
509

510
511
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is either a
                :class:`~vllm.inputs.TextPrompt` (text under ``"prompt"``,
                tokenized here) or a :class:`~vllm.inputs.TokensPrompt`
                (token IDs under ``"prompt_token_ids"``). Optional
                ``"multi_modal_data"`` and ``"mm_processor_kwargs"`` entries
                are carried over to every beam.
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per prompt, in input order, holding
            up to ``params.beam_width`` highest-scoring sequences with their
            decoded ``text`` filled in.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Used by the scoring closure and the EOS check below; fetch it
        # before defining the closure that captures it.
        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

        for _ in range(max_tokens):
            # Flatten every instance's live beams into one batch, in instance
            # order. chain.from_iterable is linear, whereas sum(..., [])
            # re-copies the accumulated list for each instance (quadratic).
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            # all_beams[pos[k]:pos[k + 1]] are the beams of instance k.
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

            prompts_batch = [
                create_tokens_prompt_from_beam(beam) for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after max_tokens compete with the completed
            # ones for the final top-`beam_width` slots.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
642
643
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered into a prompt with the chat template
        (as text, or as token IDs when the template rendering returns them),
        and the resulting prompts are passed to the :meth:`generate` method
        to generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "auto" (the default) lets
                ``resolve_chat_template_content_format`` choose between the
                formats below based on the chat template and tokenizer.
              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

            add_generation_prompt: If True, adds a generation prompt to the
                end of each rendered conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional list of tool definitions. They are forwarded to
                the chat template (and to content-format resolution) for
                models whose templates support tool use.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            prompt_data: Union[str, list[int]]
            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path renders the original OpenAI-style messages
                # (`msgs`) rather than the parsed `conversation`.
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    tools=tools,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    trust_remote_code=model_config.trust_remote_code,
                    conversation=conversation,
                    chat_template=chat_template,
                    tools=tools,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                )

            # Template rendering may yield either token IDs or text; wrap the
            # result in the matching prompt type.
            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

780
781
782
783
784
785
786
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Preferred form: prompts are passed positionally."""

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: one text prompt, optionally with its token IDs."""

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: several text prompts, optionally with token IDs."""

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: one token-ID list given by keyword."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: several token-ID lists given by keyword."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: token IDs only, passed positionally."""

nunjunj's avatar
nunjunj committed
865
866
867
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Every explicitly supplied
                ``PoolingParams`` is verified against the model config.
            prompt_token_ids: LEGACY. Token IDs for the prompts; combined
                with ``prompts`` by :meth:`_convert_v1_inputs`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the pooling
                runner, or if neither ``prompts`` nor ``prompt_token_ids``
                is provided.

        Note:
            Passing ``prompt_token_ids``, or passing ``prompts`` by keyword,
            is considered legacy and may be removed in the future. Pass the
            prompts positionally instead; token IDs can be supplied as
            :class:`~vllm.inputs.TokensPrompt` objects.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            # Without this guard `None` is forwarded and fails later with
            # an opaque TypeError from `len(None)`.
            if prompts is None:
                raise ValueError("Either prompts or prompt_token_ids must be "
                                 "provided.")
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(model_config)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
951

952
953
954
955
956
957
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per input prompt.

        The work is delegated to :meth:`encode`, which batches the prompts
        subject to memory limits; passing every prompt in a single call
        gives the best throughput.

        Args:
            prompts: A single prompt or a sequence of prompts. See
                :class:`~vllm.inputs.PromptType` for the accepted formats.
            use_tqdm: Show a tqdm progress bar while running.
            pooling_params: Pooling parameters forwarded to :meth:`encode`;
                None selects the defaults.
            lora_request: Optional LoRA request(s) applied to the prompts.
            prompt_adapter_request: Optional prompt adapter request.

        Returns:
            ``EmbeddingRequestOutput`` objects, ordered like ``prompts``.

        Raises:
            ValueError: If the model was not started with ``--task embed``.
        """
        task = self.llm_engine.model_config.task
        if task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters forwarded to
                :meth:`encode`. If None, the default pooling parameters are
                used (same behavior as before this option existed).
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model was not started with ``--task classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1037
1038
1039
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their embeddings.

        All entries of ``text_1`` and ``text_2`` are embedded in a single
        :meth:`encode` call. If ``text_1`` has exactly one entry, its
        embedding is paired with every embedding of ``text_2``; otherwise
        the two lists are paired element-wise.

        Args:
            tokenizer: Tokenizer handed to the similarity helper; also used
                to tokenize string inputs when truncation is requested.
            text_1: Query side of the pairs.
            text_2: Candidate side of the pairs.
            truncate_prompt_tokens: If set, string inputs are tokenized with
                ``truncation=True`` and ``max_length`` equal to this value
                before embedding. Prompt dicts are passed through unchanged.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per scored pair.
        """
        if truncate_prompt_tokens is not None:
            # Previously this argument was silently ignored on the embedding
            # path. Tokenize string inputs with the same truncation kwargs
            # `_cross_encoding_score` uses, and send token IDs to the engine.
            def _truncate(
                prompt: Union[str, TextPrompt, TokensPrompt]
            ) -> Union[str, TextPrompt, TokensPrompt]:
                if not isinstance(prompt, str):
                    return prompt
                encoded = tokenizer(prompt,
                                    truncation=True,
                                    max_length=truncate_prompt_tokens)
                # Attribute access works for HF `BatchEncoding`; presumably
                # also for non-dict encoding objects — confirm for custom
                # tokenizers.
                return TokensPrompt(prompt_token_ids=encoded.input_ids)

            text_1 = [_truncate(p) for p in text_1]
            text_2 = [_truncate(p) for p in text_2]

        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        encoded_output_1: list[PoolingRequestOutput] = encoded_output[
            :len(text_1)]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[
            len(text_1):]

        if len(encoded_output_1) == 1:
            # 1 -> N: reuse the single query embedding for every candidate.
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by feeding each pair jointly to a cross-encoder.

        Each ``(query, document)`` pair is tokenized together through the
        tokenizer's ``text``/``text_pair`` arguments, submitted as a
        ``TokensPrompt`` (including ``token_type_ids`` when the tokenizer
        returns them), and run with default ``PoolingParams``. A single
        query in ``text_1`` is broadcast against every entry of ``text_2``.

        Args:
            tokenizer: Tokenizer used to build the joint pair encodings.
                Mistral tokenizers are rejected.
            text_1: Query strings.
            text_2: Document strings.
            truncate_prompt_tokens: When set, enables tokenizer truncation
                with this value as ``max_length``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``.
        """
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is only enabled for `--task embed or score`")

        queries = text_1 * len(text_2) if len(text_1) == 1 else text_1

        tokenization_kwargs: dict[str, Any] = ({} if truncate_prompt_tokens
                                               is None else {
                                                   "truncation": True,
                                                   "max_length":
                                                   truncate_prompt_tokens,
                                               })

        def _to_engine_prompt(query: str, document: str) -> TokensPrompt:
            encoding = tokenizer(text=query,
                                 text_pair=document,
                                 **tokenization_kwargs)
            return TokensPrompt(
                prompt_token_ids=encoding["input_ids"],
                token_type_ids=encoding.get("token_type_ids"))

        engine_prompts = [
            _to_engine_prompt(query, document)
            for query, document in zip(queries, text_2)
        ]

        self._validate_and_add_requests(
            prompts=engine_prompts,
            params=PoolingParams(),
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        raw_outputs = self._run_engine(use_tqdm=use_tqdm)
        validated = self.engine_class.validate_outputs(raw_outputs,
                                                       PoolingRequestOutput)
        return list(map(ScoringRequestOutput.from_base, validated))

1123
1124
1125
1126
1127
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.

        For cross-encoder models each pair is tokenized jointly and scored by
        the model; otherwise both sides are embedded and compared with cosine
        similarity. This method automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional maximum number of tokens per
                tokenized input; forwarded to the scoring helper.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model started with
                ``--task embed`` or ``--task score``, if a prompt carries
                multi-modal data, or if a prompt cannot be converted to text.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            """Reduce a singleton prompt to plain text for scoring."""
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Validate explicitly: the previous `assert` is stripped under
            # `python -O`, which would let unconverted dicts through.
            if not isinstance(prompt, str):
                raise ValueError(
                    "Scoring prompts must be strings, TextPrompts or "
                    f"TokensPrompts; got {type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1223

1224
1225
1226
1227
1228
1229
    def start_profile(self) -> None:
        """Forward a start-profiling request to ``self.llm_engine``."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Forward a stop-profiling request to ``self.llm_engine``."""
        engine = self.llm_engine
        engine.stop_profile()

1230
1231
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        ``device`` is forwarded unchanged to
        ``self.llm_engine.reset_prefix_cache``, and that call's boolean
        result is returned as-is.
        """
        engine = self.llm_engine
        succeeded = engine.reset_prefix_cache(device)
        return succeeded
1232

1233
1234
1235
1236
1237
1238
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level.

                - Level 1 offloads the model weights and discards the kv
                  cache, whose content is forgotten. The weights are backed
                  up in CPU memory, so make sure there is enough CPU memory
                  to hold them. Use it to sleep and later wake up the engine
                  to run the same model again.
                - Level 2 discards both the model weights and the kv cache;
                  the content of both is forgotten. Use it to sleep and later
                  wake up the engine to run a different or updated model,
                  where the previous weights are not needed. It reduces CPU
                  memory pressure.
        """
        # Clear cached prefixes before the engine goes to sleep; as described
        # above, both levels discard the kv cache contents.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1255
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the :meth:`sleep` method
        for more details.

        Args:
            tags: Optional list of tags selecting which memory allocations to
                reallocate. Values must be in ("weights", "kv_cache",). If
                None, all memory is reallocated. ``wake_up`` should have been
                called with all tags (or with None) before the engine is used
                again.
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1268

1269
1270
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Turn legacy ``prompts``/``prompt_token_ids`` arguments into prompts.

        Both arguments are normalised to lists with ``parse_and_batch_prompt``.
        When text prompts are present, a ``TextPrompt`` is built for each one
        and any token IDs are used only for the length check. Otherwise a
        ``TokensPrompt`` is built from each token-ID list.

        Raises:
            ValueError: If both arguments are None, or if both are given with
                different lengths.
        """
        # skip_tokenizer_init is now checked in engine
        texts = None if prompts is None else [
            entry["content"] for entry in parse_and_batch_prompt(prompts)
        ]
        token_lists = None if prompt_token_ids is None else [
            entry["content"]
            for entry in parse_and_batch_prompt(prompt_token_ids)
        ]

        if texts is None and token_lists is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        if (texts is not None and token_lists is not None
                and len(texts) != len(token_lists)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        parsed_prompts: list[PromptType]
        if texts is not None:
            parsed_prompts = [TextPrompt(prompt=text) for text in texts]
        else:
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=token_ids)
                for token_ids in token_lists
            ]
        return parsed_prompts
1312
1313
1314

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments, then enqueue every prompt.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt. Every ``SamplingParams``
                is mutated in place: the legacy guided options are merged in
                and ``output_kind`` is set to ``FINAL_ONLY``.
            lora_request: One LoRA request shared by every prompt, a sequence
                with exactly one entry per prompt, or ``None``.
            prompt_adapter_request: Applied to every prompt.
            guided_options: Deprecated. Merged into each ``SamplingParams``
                via ``_add_guided_params``.
            priority: Optional per-prompt priorities. When it is ``None`` or
                empty, priority 0 is used for every prompt.

        Raises:
            ValueError: If a per-prompt sequence does not have exactly one
                entry per prompt. Length validation runs before any request
                is added, so a mismatch no longer leaves the engine holding
                only part of the batch.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Test for Sequence (not just list) so that tuples are validated and
        # processed exactly the way they are indexed further below.
        params_is_seq = isinstance(params, Sequence)
        lora_is_seq = isinstance(lora_request, Sequence)

        if params_is_seq and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if lora_is_seq and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # An empty priority list means "no priorities" (see the
        # `if priority` below), so only non-empty lists are checked.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if params_is_seq else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if params_is_seq else params,
                lora_request=lora_request[i] if lora_is_seq else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1361

1362
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the underlying engine.

        The request ID is the next value drawn from ``self.request_counter``,
        converted with ``str()``. All other arguments are forwarded to
        ``self.llm_engine.add_request`` unchanged.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1379

1380
    def _add_guided_params(
        self,
        params: SamplingParams,
        guided_options: Optional[GuidedDecodingRequest] = None,
    ):
        """Fold legacy guided-decoding options into ``params``.

        ``params`` is modified in place and also returned. When
        ``guided_options`` is ``None``, ``params`` is returned untouched.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
            )

        return params

1401
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to show a progress bar that also reports
                estimated input/output token throughput (computed only from
                ``RequestOutput`` results).

        Returns:
            All finished outputs, sorted by ``int(request_id)``.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue

                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)

                        # tqdm reports elapsed == 0 for disabled bars (e.g.
                        # TQDM_DISABLE=1), and it can also be 0.0 on coarse
                        # clocks. Skip the throughput estimate rather than
                        # dividing by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            # Close the bar even if an engine step raises, so the terminal
            # is not left with a dangling progress bar.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were submitted before them.
        return sorted(outputs, key=lambda x: int(x.request_id))