llm.py 64.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
17
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
18
19
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
20
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
21
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
22
                                         ChatTemplateContentFormatOption,
23
24
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
25
26
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
27
28
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
29
from vllm.entrypoints.utils import _validate_truncation_size
30
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
31
from vllm.inputs.parse import parse_and_batch_prompt
32
from vllm.logger import init_logger
33
from vllm.lora.request import LoRARequest
34
35
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
36
from vllm.model_executor.layers.quantization import QuantizationMethods
37
38
39
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
40
from vllm.pooling_params import PoolingParams
41
from vllm.prompt_adapter.request import PromptAdapterRequest
42
43
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
44
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
45
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
46
from vllm.usage.usage_lib import UsageContext
47
48
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
49

logger = init_logger(__name__)

# Generic return type of worker-side callables passed to `collective_rpc` and
# `apply_model`; falls back to `Any` when it cannot be inferred.
_R = TypeVar("_R", default=Any)


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights, one of
            the `QuantizationMethods` values (for example "awq", "gptq", or
            "fp8"). If None, we first check the `quantization_config`
            attribute in the model config file. If that is None, we assume
            the model weights are not quantized and use `dtype` to determine
            the data type of the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Note that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See {class}`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's
            multi-modal processor (passed through to
            {class}`~vllm.EngineArgs`).
        task: The task to initialize the model for (e.g. "generate");
            passed through to {class}`~vllm.EngineArgs`. Defaults to "auto".
        override_pooler_config: Overrides for the pooler configuration of
            pooling models (passed through to {class}`~vllm.EngineArgs`).
        compilation_config: Either an integer, a dictionary, or a
            {class}`~vllm.config.CompilationConfig` instance. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration;
            keys that are not init fields of `CompilationConfig` are ignored.
            An instance is used as-is.
        **kwargs: Arguments for {class}`~vllm.EngineArgs`. (See
            {ref}`engine-args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the {class}`~vllm.AsyncLLMEngine` class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = True
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    {meth}`LLM.__init__`.
    """

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily turn on deprecation of the legacy generate/encode API.

        Sets `DEPRECATE_LEGACY` to True for the duration of the ``with``
        block. On exit, including when the block raises, the value that was
        in effect before entering is restored.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore rather than hard-code False. The class default is True,
            # so forcing False would silently disable deprecation afterwards.
            cls.DEPRECATE_LEGACY = previous

    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Packs the arguments (documented on the class) into an `EngineArgs`
        and creates the underlying engine via `LLMEngine.from_engine_args`.
        """
        # Stats logging is off by default; callers can still opt in by
        # passing `disable_log_stats=False`.
        kwargs.setdefault("disable_log_stats", True)

        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            # A class (rather than a qualified-name string) is serialized
            # with cloudpickle to avoid pickling issues.
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        compilation_config_instance = self._resolve_compilation_config(
            compilation_config)

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled in lazily by `get_default_sampling_params`.
        self.default_sampling_params: Union[dict[str, Any], None] = None

    @staticmethod
    def _resolve_compilation_config(
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]],
    ) -> Optional[CompilationConfig]:
        """Normalize the ``compilation_config`` argument of ``__init__``.

        An int is used as the compilation level. A dict is turned into a
        ``CompilationConfig`` built from the keys that are init fields of
        that class; the remaining keys are dropped with a warning. Any other
        value (e.g. ``None`` or an existing ``CompilationConfig``) is
        returned unchanged.
        """
        if isinstance(compilation_config, int):
            return CompilationConfig(level=compilation_config)
        if isinstance(compilation_config, dict):
            accepted = {
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            }
            ignored = [
                key for key in compilation_config if key not in accepted
            ]
            if ignored:
                # Previously these keys were dropped silently, hiding typos.
                logger.warning(
                    "Ignoring compilation_config keys that are not fields "
                    "of CompilationConfig: %s", ignored)
            return CompilationConfig(**accepted)
        return compilation_config

    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for the given LoRA request (if any).

        The lookup is delegated to the engine's tokenizer group through its
        `get_lora_tokenizer` method.
        """
        group = self.llm_engine.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer that does not already look cached is wrapped with
        `get_cached_tokenizer` first.
        """
        group = self.llm_engine.get_tokenizer_group()

        # The cached tokenizer class is created dynamically, so its name
        # prefix is the only available marker. A user-defined tokenizer
        # whose class name starts with 'Cached' will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))

    def get_default_sampling_params(self) -> SamplingParams:
        """Return the sampling parameters used when the caller passes none.

        The model config's `get_diff_sampling_param()` result is fetched on
        first use and cached in `self.default_sampling_params`. An empty
        result yields plain `SamplingParams()` defaults.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

    # Typing-only signatures for `generate`. Each one declares `priority`,
    # which the implementation accepts, so type checkers do not reject
    # `generate(..., priority=[...])`.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See {class}`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing token IDs (one list, or
                one list per prompt) alongside `prompts`. Deprecated; pass
                token IDs through `prompts` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, either a
                `GuidedDecodingRequest` or a dict with at most one entry that
                is used to construct one.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the "generate" or
                "transcription" runner, or if `guided_options_request` is a
                dict with more than one entry.

        Note:
            Passing token IDs through `prompt_token_ids` is legacy and is
            flagged by `deprecate_kwargs` while `LLM.DEPRECATE_LEGACY` is set.
            Pass them via the `prompts` parameter instead (e.g. as a
            {class}`~vllm.inputs.TokensPrompt`).
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ("generate", "transcription"):
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge `prompts` + `prompt_token_ids` into the
            # unified prompt format.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def collective_rpc(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> list[_R]:
        """
        Run ``method`` on every worker and gather what each one returns.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and sent to all workers. A callable receives an
                additional ``self`` argument (the worker object) on top of
                ``args`` and ``kwargs``.
            timeout: Seconds to wait for execution before a
                {exc}`TimeoutError` is raised; ``None`` waits indefinitely.
            args: Positional arguments for the worker method.
            kwargs: Keyword arguments for the worker method.

        Returns:
            A list holding the result from each worker.

        Note:
            Prefer this API for control messages only, and set up a separate
            data-plane communication channel for passing data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call ``func`` on the model inside every worker and collect the
        per-worker return values.
        """
        return self.llm_engine.model_executor.apply_model(func)

    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is either a `TextPrompt`
                (its "prompt" string is tokenized here) or a `TokensPrompt`
                (its "prompt_token_ids" are used directly). Optional
                "multi_modal_data" and "mm_processor_kwargs" entries are
                carried over to every beam of that prompt.
            params: The beam search parameters.

        Returns:
            One `BeamSearchOutput` per prompt, in input order. Each holds up
            to `params.beam_width` sequences sorted by beam search score (best
            first), with `text` set to the decoded tokens.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Fetched before the closures below so they never depend on a name
        # that is bound later in the function.
        tokenizer = self.get_tokenizer()
        # Invariant for the whole search; read it once instead of per beam.
        eos_token_id = tokenizer.eos_token_id

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

        for _ in range(max_tokens):
            # Flatten the live beams of every instance into one batch.
            # chain.from_iterable is linear; sum(lists, []) re-copies the
            # growing prefix for every instance (quadratic in #prompts).
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # all_beams[start:end] are the beams of the matching instance.
            pos = list(
                itertools.accumulate(
                    (len(instance.beams) for instance in instances),
                    initial=0))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                create_tokens_prompt_from_beam(beam) for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after the last step compete with the
            # finished ones for the final top-`beam_width` slots.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Every conversation is rendered with the chat template, tokenized,
        and the resulting token prompts are handed to the {meth}`generate`
        method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: Either a single conversation or a list of conversations.

              - A conversation is a list of messages.
              - A message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, the default sampling parameters are used. A single
                value applies to every conversation; a list must match the
                number of conversations and is paired with them in order.
            use_tqdm: Whether to display a tqdm progress bar.
            lora_request: LoRA request to use for generation, if any. It also
                selects the tokenizer used to render the conversations.
            chat_template: Template used to structure the chat. When omitted,
              the model's default chat template is used.
            chat_template_content_format: How message content is rendered.

              - "string" renders the content as a plain string.
                Example: ``"Who are you?"``
              - "openai" renders the content as a list of dictionaries,
                similar to the OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, the final message of the
                conversation is continued instead of starting a new one.
                Cannot be ``True`` if ``add_generation_prompt`` is also
                ``True``.
            tools: Optional tool definitions. They are taken into account
                when resolving the content format and are forwarded to the
                chat template.
            chat_template_kwargs: Extra kwargs forwarded to the chat template;
                they override the explicit template arguments above on
                key collisions.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects holding the generated
            responses, ordered like the input conversations.
        """
        # A list of lists is a batch of conversations; anything else is a
        # single conversation that gets wrapped into a batch of one.
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list) else
            [cast(list[ChatCompletionMessageParam], messages)])

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Explicit template arguments first; user-supplied
        # ``chat_template_kwargs`` win on key collisions.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        engine_prompts: list[Union[TokensPrompt, TextPrompt]] = []
        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            parsed_conversation, mm_data = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if not isinstance(tokenizer, MistralTokenizer):
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=parsed_conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already emits the special tokens, so the
                # tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)
            else:
                # Mistral tokenizers render straight to token ids.
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            engine_prompts.append(engine_prompt)

        return self.generate(
            engine_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

794
795
796
797
798
799
800
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[
            Union[PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Preferred form: prompt(s) passed positionally via ``prompts``."""

809
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[
            Union[PoolingParams, Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: one text prompt plus optional token IDs."""
823

824
    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[
            Union[PoolingParams, Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: several text prompts plus optional token IDs."""

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[
            Union[PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: one token-ID sequence plus an optional text prompt."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[
            Union[PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: several token-ID sequences plus optional prompts."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: token IDs only, with ``prompts``/``pooling_params``
        passed positionally as ``None``."""

nunjunj's avatar
nunjunj committed
885
886
887
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See {class}`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. A single value or every
                element of a sequence is verified against the model config.
            prompt_token_ids: LEGACY. Token IDs of the prompts; combined with
                ``prompts`` (if given) into engine prompts. Deprecated in
                favor of passing token prompts through ``prompts``.
            truncate_prompt_tokens: Optional truncation size for the prompt
                tokens. It is checked against the model's ``max_model_len``
                and forwarded to tokenization.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the pooling
                runner.

        Note:
            Passing `prompt_token_ids` is considered legacy and may be
            deprecated in the future. You should instead pass the token IDs
            via the `prompts` parameter (e.g. as a `TokensPrompt`).
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                # Point users at the fix when the model could pool but was
                # started with a different runner.
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # LEGACY path: merge separate text/token-id arguments into
            # engine prompts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(self.llm_engine.model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(self.llm_engine.model_config)

        # Filled in by the helper based on truncate_prompt_tokens.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
978

979
980
981
982
983
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Produce one embedding vector per prompt.

        Prompts are batched automatically with the memory constraint in
        mind, so for best throughput pass all prompts in a single list.

        Args:
            prompts: The prompts to the LLM. A sequence of prompts may be
                given for batch inference. See {class}`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional truncation size for the prompt
                tokens, forwarded to `encode`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters for pooling. If None, the
                default pooling parameters are used.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not started with ``--task embed``.
        """
        current_task = self.llm_engine.model_config.task
        if current_task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled_outputs = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled_outputs))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See {class}`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional truncation size for the prompt
                tokens, forwarded to `encode` (same as in `embed`).
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model was not started with ``--task classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1066
1067
1068
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by cosine similarity of their pooled embeddings.

        Both sides are encoded in a single `encode` call. The outputs are then
        split back into the ``text_1`` part and the ``text_2`` part. When
        ``text_1`` yields a single embedding, that embedding is reused for
        every ``text_2`` entry.
        """
        split_at = len(text_1)
        pooled: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        queries = pooled[:split_at]
        documents = pooled[split_at:]

        # 1 -> N: broadcast the single query embedding over all documents.
        if len(queries) == 1:
            queries = queries * len(documents)

        similarities = _cosine_similarity(tokenizer=tokenizer,
                                          embed_1=queries,
                                          embed_2=documents)

        validated = self.engine_class.validate_outputs(similarities,
                                                       PoolingRequestOutput)
        return list(map(ScoringRequestOutput.from_base, validated))

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score ``(text_1[i], text_2[i])`` pairs with a cross-encoder model.

        Each pair is tokenized jointly (``text`` / ``text_pair``) and the
        resulting token prompts are run through the pooling engine. A single
        ``text_1`` entry is replicated to pair with every ``text_2`` entry.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``, which this
                path does not support.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The task has already been validated by the caller; the actual
            # limitation here is the tokenizer type.
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []
        for query, document in zip(text_1, text_2):
            # tokenization_kwargs (filled by _validate_truncation_size) must
            # be applied here: the engine only receives pre-tokenized prompts.
            prompt_inputs = tokenizer(text=query,
                                      text_pair=document,
                                      **tokenization_kwargs)
            parsed_prompts.append(
                TokensPrompt(
                    prompt_token_ids=prompt_inputs["input_ids"],
                    token_type_ids=prompt_inputs.get("token_type_ids")))

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1151
1152
1153
1154
1155
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        For cross-encoder models the input pairs are tokenized jointly.
        Other models score each pair by the cosine similarity of the two
        pooled embeddings. Prompts are batched automatically, considering
        the memory constraint. For the best performance, put all of your
        texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See {class}`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional truncation size for the prompt
                tokens, forwarded to the scoring implementation.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, was not started
                with ``--task embed`` or ``--task score``, a prompt carries
                multi-modal data, or a prompt cannot be turned into text.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs, so token-id
        # prompts are decoded back to text by ensure_str below.
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            """Normalize a single prompt to plain text."""
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check instead of `assert`, so the validation is not
            # stripped when Python runs with -O.
            if not isinstance(prompt, str):
                raise ValueError("Unsupported prompt for scoring: "
                                 f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1251

1252
1253
1254
1255
1256
1257
    def start_profile(self) -> None:
        """Ask the underlying engine to start profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1258
1259
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, forwarded unchanged to the
                engine.

        Returns:
            The boolean reported by the engine's ``reset_prefix_cache``
            (presumably whether the reset succeeded; confirm in the engine).
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1260

1261
1262
1263
1264
1265
1266
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset first, then the engine is put to sleep.

        Args:
            level: The sleep level.

                - Level 1 offloads the model weights and discards the kv
                  cache, whose content is forgotten. The weights are backed
                  up in CPU memory, so make sure there is enough CPU memory
                  to hold them. Good for sleeping and waking up the engine to
                  run the same model again.
                - Level 2 discards both the model weights and the kv cache;
                  the content of both is forgotten. Good for sleeping and
                  waking up the engine to run a different model or to update
                  the model, when the previous weights are not needed. It
                  reduces CPU memory pressure.
        """
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1283
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the {meth}`sleep` method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                ("weights", "kv_cache",). If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        engine = self.llm_engine
        engine.wake_up(tags)

    # LEGACY
    def _convert_v1_inputs(
1299
        self,
1300
1301
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
1302
1303
    ):
        # skip_tokenizer_init is now checked in engine
1304

1305
1306
1307
1308
1309
1310
1311
1312
1313
        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")
        if prompts is not None and prompt_token_ids is not None \
                and len(prompts) != len(prompt_token_ids):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

1314
1315
1316
1317
1318
1319
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]
1320
1321
        if prompts is not None:
            num_requests = len(prompts)
1322
        elif prompt_token_ids is not None:
1323
            num_requests = len(prompt_token_ids)
1324
        parsed_prompts: list[PromptType] = []
1325
        for i in range(num_requests):
1326
            item: PromptType
1327

1328
            if prompts is not None:
1329
1330
1331
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
1332
            else:
1333
                raise AssertionError
1334

1335
            parsed_prompts.append(item)
1336

1337
        return parsed_prompts
1338
1339
1340

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: bool,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and submit every prompt to the
        engine.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: One params object shared by all prompts, or a sequence
                with exactly one entry per prompt.
            use_tqdm: Show a progress bar while adding requests.
            lora_request: One LoRA request shared by all prompts, or a
                sequence with exactly one entry per prompt.
            prompt_adapter_request: Passed unchanged to every request.
            tokenization_kwargs: Passed unchanged to every request.
            guided_options: Deprecated. When given, it is merged into each
                ``SamplingParams`` via ``_add_guided_params``.
            priority: Optional per-prompt priorities. An omitted or empty
                list means priority 0 for every request.

        Raises:
            ValueError: If ``params``, ``lora_request`` or a non-empty
                ``priority`` is a sequence whose length differs from the
                number of prompts. All checks run before any request is
                added, so a mismatch never leaves a partial batch in the
                engine.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use the same ``Sequence`` test here as for indexing below. A tuple
        # is then length-checked (and its SamplingParams updated) exactly
        # like a list, instead of failing with IndexError mid-submission.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            it = tqdm(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )

    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the engine under a fresh request id.

        The id is the next value drawn from ``self.request_counter``,
        converted to ``str``. All other arguments are forwarded unchanged to
        ``self.llm_engine.add_request``.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            prompt,
            params,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Fold deprecated ``guided_options`` into ``params.guided_decoding``.

        ``params`` is mutated in place and returned. When ``guided_options``
        is None, ``params`` is returned untouched.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Show a progress bar. For ``RequestOutput`` results it
                also shows estimated input/output token throughput.

        Returns:
            Every finished output, sorted by ``int(request_id)``.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if not output.finished:
                    continue
                outputs.append(output)
                if not use_tqdm:
                    continue

                if isinstance(output, RequestOutput):
                    # Calculate tokens only for RequestOutput
                    n = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids) * n
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs)

                    # tqdm can report an elapsed time of 0 (e.g. for a
                    # disabled bar, or on an early update with a coarse
                    # clock). In that case skip the speed estimate instead of
                    # dividing by zero after all the work is done.
                    elapsed = pbar.format_dict["elapsed"]
                    if elapsed:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                    pbar.update(n)
                else:
                    pbar.update(1)

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))