llm.py 62.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
from vllm.config import CompilationConfig
17
18
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
19
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
20
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
21
                                         ChatTemplateContentFormatOption,
22
23
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
24
25
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
26
27
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
28
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
29
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
30
from vllm.logger import init_logger
31
from vllm.lora.request import LoRARequest
32
33
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
34
35
36
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
37
from vllm.pooling_params import PoolingParams
38
from vllm.prompt_adapter.request import PromptAdapterRequest
39
40
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
41
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
42
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
43
from vllm.usage.usage_lib import UsageContext
44
45
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
46
import vllm.envs as envs
lizhigong's avatar
lizhigong committed
47
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
48

49
50
logger = init_logger(__name__)

# Generic return type for worker-side calls such as `collective_rpc` and
# `apply_model`; falls back to `Any` when a type checker cannot infer it.
_R = TypeVar("_R", default=Any)

53
54

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
55
56
57
58
59
60
61
62
63
64
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, values that are too small may cause out-of-memory (OOM)
            errors. Note that `best_of` is only supported in V0.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_seq_len_to_capture: Maximum sequence length covered by CUDA
            graphs. When a sequence has context length larger than this, we
            fall back to eager mode. Additionally, for encoder-decoder models,
            if the sequence length of the encoder input is larger than this,
            we fall back to eager mode.
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`.
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine-args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """
137

138
    #: A flag to toggle whether to deprecate the legacy generate/encode API.
    DEPRECATE_LEGACY: ClassVar[bool] = True

    #: A flag to toggle whether to deprecate positional arguments in
    #: :meth:`LLM.__init__`.
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True

147
148
149
150
151
152
153
154
155
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Force legacy-API deprecation warnings on inside a ``with`` block.

        Sets ``cls.DEPRECATE_LEGACY`` to ``True`` on entry. On exit, including
        when the block raises, it restores the value the flag had on entry.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore the prior value instead of hard-coding False. The class
            # default is True, so a blanket reset would silently turn off
            # deprecation warnings for the remainder of the process. An
            # exception in the block must not leave the flag modified.
            cls.DEPRECATE_LEGACY = previous

156
157
158
159
160
161
162
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        # NOTE: 9G models need tokenizer_mode="cpm" for their tokenizer.
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Packs the arguments (plus any extra ``kwargs``) into an
        :class:`EngineArgs` and builds the engine from it. The engine is
        ``ZeroOverheadEngine`` when ``envs.VLLM_ZERO_OVERHEAD`` is set and
        ``LLMEngine`` otherwise.

        Note:
            If ``enforce_eager`` is left unset (``None``), it defaults to
            ``False``.
        """
        # Stats logging stays off unless the caller asked for it explicitly.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class given as a class object (rather than a qualified
        # name string) is serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # Normalize compilation_config: int/dict go through the CLI parser,
        # anything else (e.g. an existing config object) is used as-is.
        if compilation_config is None:
            compilation_config_instance = None
        elif isinstance(compilation_config, (int, dict)):
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the engine; both classes expose the same factory.
        # (LLMEngine itself autoselects V0 vs V1.)
        engine_factory = (ZeroOverheadEngine
                          if envs.VLLM_ZERO_OVERHEAD else LLMEngine)
        self.llm_engine = engine_factory.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None
261

262
263
264
265
266
267
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer that the engine's tokenizer group resolves
        for ``lora_request``."""
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
268
269

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        The tokenizer is wrapped with ``get_cached_tokenizer`` unless it
        already looks like a cached tokenizer.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are generated dynamically, so the class
        # name is the only available marker. A user-defined tokenizer whose
        # class name starts with 'Cached' will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        tokenizer_group.tokenizer = (tokenizer if looks_cached else
                                     get_cached_tokenizer(tokenizer))
279

280
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the sampling params used when the caller supplies none.

        The model config's ``get_diff_sampling_param()`` result is fetched
        on first use and cached in ``self.default_sampling_params``. A
        non-empty result is expanded into ``SamplingParams.from_optional``;
        otherwise a plain ``SamplingParams()`` is returned.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

288
289
290
291
292
293
294
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Preferred form: ``prompts`` is passed positionally as one prompt
        or a sequence of prompts (see :class:`~vllm.inputs.PromptType`)."""

304
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy form: one text prompt plus optional token IDs."""

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy form: a list of text prompts plus optional token-ID lists."""

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy form: keyword ``prompt_token_ids`` for one prompt, with an
        optional text prompt."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy form: keyword ``prompt_token_ids`` for many prompts, with
        optional text prompts."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        """Legacy form: token IDs passed positionally after two ``None``
        placeholders for ``prompts`` and ``sampling_params``."""

nunjunj's avatar
nunjunj committed
385
386
387
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Deprecated legacy input. Token IDs for one
                prompt (``list[int]``) or many (``list[list[int]]``). When
                given, it is combined with ``prompts`` (treated as optional
                text) into the regular prompt format.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A plain
                dict may name at most one guided decoding option and is
                converted into a ``GuidedDecodingRequest``.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the "generate"
                or "transcription" runner, or if ``guided_options_request``
                is a dict with more than one entry.

        Note:
            Passing token IDs through the ``prompt_token_ids`` keyword is
            legacy and deprecated. Pass them via the ``prompts`` parameter
            instead, e.g. as a :class:`~vllm.inputs.TokensPrompt`.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge text prompts and token IDs into the regular
            # prompt format.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            # Only a single guided decoding mode may be requested at once.
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
483

484
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """Execute an RPC call on all workers.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and sent to every worker. A callable must accept an
                additional ``self`` argument (the worker object) besides the
                values passed through ``args`` and ``kwargs``.
            timeout: Maximum number of seconds to wait. A
                :exc:`TimeoutError` is raised on timeout; ``None`` waits
                indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list holding the result from each worker.

        Note:
            Use this API only for control messages; set up a separate
            data-plane communication channel for passing data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
513
514

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Run ``func`` directly on the model inside each worker.

        Returns:
            A list with the value ``func`` returned in each worker.
        """
        return self.llm_engine.model_executor.apply_model(func)
521

522
523
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompt dicts. For a ``TokensPrompt`` its
                ``prompt_token_ids`` are used directly; for a ``TextPrompt``
                the ``prompt`` text is encoded with the tokenizer. Optional
                ``multi_modal_data`` and ``mm_processor_kwargs`` entries are
                carried along with every beam. Entries are accessed by key,
                so bare strings or bare token-ID lists are not accepted.
            params: The beam search parameters.

        Returns:
            One ``BeamSearchOutput`` per prompt, in input order. Each holds
            at most ``params.beam_width`` sequences ordered by beam-search
            score (highest first), with ``text`` decoded from their tokens.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Resolve the tokenizer before defining the helpers below; the sort
        # key closes over it.
        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

        for _ in range(max_tokens):
            # Flatten every instance's live beams into one batch. chain() is
            # linear in the total beam count, whereas sum(lists, []) re-copies
            # the accumulated list on every addition (quadratic).
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            # all_beams[pos[k]:pos[k + 1]] are the beams of instances[k].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            if not all_beams:
                break

            prompts_batch = [
                create_tokens_prompt_from_beam(beam) for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after max_tokens compete with completed ones.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
654
655
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered with the chat template, tokenized into
        a :class:`~vllm.inputs.TokensPrompt`, and the resulting prompts are
        passed to the :meth:`generate` method to generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "auto" (default) lets
                :func:`resolve_chat_template_content_format` choose the
                format from the chat template and tokenizer.
              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional tool definitions. They are forwarded to the chat
                template rendering and to content-format resolution.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            trust_remote_code=model_config.trust_remote_code,
        )

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # Mistral tokenizers render the raw messages straight to
                # token IDs.
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    tools=tools,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer,
                    trust_remote_code=model_config.trust_remote_code,
                    conversation=conversation,
                    chat_template=chat_template,
                    tools=tools,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

791
792
793
794
795
796
797
    # Preferred signature: prompts (text/token/multimodal) are passed
    # positionally via ``prompts``.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
876
877
878
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Explicitly supplied
                params (a single value or a sequence) are verified against
                the model config.
            prompt_token_ids: LEGACY. Token IDs of one or more prompts.
                Deprecated; pass :class:`~vllm.inputs.TokensPrompt` objects
                via ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the engine is not running the ``pooling`` runner.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and is
            deprecated. Pass token IDs through the ``prompts`` parameter
            (as :class:`~vllm.inputs.TokensPrompt`) instead.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call style: merge prompts/prompt_token_ids into dicts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(self.llm_engine.model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(self.llm_engine.model_config)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
962

963
964
965
966
967
968
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with
                ``--task embed``.
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with
                ``--task classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1048
1049
1050
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their embeddings.

        All inputs of ``text_1`` and ``text_2`` are embedded in one
        :meth:`encode` call. If ``text_1`` has a single element, its
        embedding is reused for every element of ``text_2``.

        Args:
            tokenizer: Tokenizer used for truncation and passed through to
                the similarity computation.
            text_1: First elements of the pairs (length 1 or len(text_2)).
            text_2: Second elements of the pairs.
            truncate_prompt_tokens: If set, string inputs are tokenized with
                ``truncation=True`` and ``max_length`` set to this value
                before embedding. This is the same tokenizer truncation used
                by the cross-encoder scoring path.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair.
        """
        prompts: list[Union[str, TextPrompt, TokensPrompt]] = text_1 + text_2

        if truncate_prompt_tokens is not None:
            # Previously this argument was silently ignored on the embedding
            # path. Tokenize text inputs here so the truncation is applied.
            tokenization_kwargs: dict[str, Any] = {
                "truncation": True,
                "max_length": truncate_prompt_tokens,
            }
            prompts = [
                TokensPrompt(prompt_token_ids=tokenizer(
                    text=prompt, **tokenization_kwargs)["input_ids"])
                if isinstance(prompt, str) else prompt for prompt in prompts
            ]

        encoded_output: list[PoolingRequestOutput] = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        encoded_output_1: list[PoolingRequestOutput] = encoded_output[
            0:len(text_1)]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[
            len(text_1):]

        if len(encoded_output_1) == 1:
            # 1 -> N case: pair the single query with every text_2 entry.
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score ``(text_1[i], text_2[i])`` pairs with a cross-encoder.

        Each pair is tokenized jointly (``text`` / ``text_pair``) and sent
        to the engine as a single pooling request. If ``text_1`` has one
        element it is replicated to match ``text_2``.

        Args:
            tokenizer: Tokenizer used to encode the text pairs.
            text_1: First elements of the pairs (length 1 or len(text_2)).
            text_2: Second elements of the pairs.
            truncate_prompt_tokens: If set, the joint encoding is truncated
                with ``truncation=True`` and this ``max_length``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The task itself was already validated by the caller; the
            # actual limitation is the tokenizer type.
            raise ValueError(
                "Cross-encoder scoring is not supported with "
                "MistralTokenizer.")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1132
1133
1134
1135
1136
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated
        ``N`` times to pair with the ``text_2`` sentences.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This method automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: If set, inputs are truncated to at most
                this many tokens by the tokenizer before scoring.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the engine is not a pooling model configured for
                ``--task embed`` or ``--task score``, or if a prompt carries
                multi-modal data.
            TypeError: If a prompt cannot be converted to a string.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            # Normalize every supported prompt form to plain text.
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check instead of `assert`, which is stripped under -O.
            if not isinstance(prompt, str):
                raise TypeError("Unsupported prompt for scoring: expected a "
                                "string, TextPrompt or TokensPrompt, got "
                                f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1232

1233
1234
1235
1236
1237
1238
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1239
1240
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Forwarded unchanged to the engine's
                ``reset_prefix_cache``.

        Returns:
            The value reported by the engine.
        """
        engine = self.llm_engine
        result = engine.reset_prefix_cache(device)
        return result
1241

1242
1243
1244
1245
1246
1247
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # Both sleep levels discard the kv cache contents, so any cached
        # prefix entries would be stale afterwards; drop them first.
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

1264
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the :meth:`sleep` method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                ("weights", "kv_cache",). If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1277

1278
1279
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ) -> list[PromptType]:
        """Convert legacy ``prompts``/``prompt_token_ids`` into prompt dicts.

        Both arguments are first normalized to batches via
        ``parse_and_batch_prompt``. If both are given, their batch sizes must
        match, and the text prompts are used.

        Args:
            prompts: A single text prompt or a list of them, or None.
            prompt_token_ids: Token IDs of one prompt or of a batch, or None.

        Returns:
            ``TextPrompt`` items when ``prompts`` is given, otherwise
            ``TokensPrompt`` items.

        Raises:
            ValueError: If both are None, or their batch sizes differ.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        num_requests = None

        if prompts is not None:
            num_requests = len(prompts)

        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

            num_requests = len(prompt_token_ids)

        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        parsed_prompts: list[PromptType]
        if prompts is not None:
            # Text takes precedence; token IDs only took part in the
            # length check above.
            parsed_prompts = [TextPrompt(prompt=p) for p in prompts]
        elif prompt_token_ids is not None:
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=ids) for ids in prompt_token_ids
            ]
        else:
            # Unreachable: num_requests would have been None above.
            raise AssertionError

        return parsed_prompts
1321
1322
1323

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-prompt arguments, then submit every prompt.

        Args:
            prompts: A single prompt (``str`` or prompt ``dict``) or a
                sequence of prompts.
            params: One params object shared by all prompts, or a sequence
                with exactly one entry per prompt. Each ``SamplingParams``
                is mutated in place: guided options are merged in and
                ``output_kind`` is set to ``FINAL_ONLY``.
            lora_request: ``None``, one LoRA request shared by all prompts,
                or a sequence with one entry per prompt.
            prompt_adapter_request: Applied to every prompt.
            guided_options: Deprecated. Merged into each ``SamplingParams``
                via ``_add_guided_params``.
            priority: Optional per-prompt priorities. If falsy, every
                request gets priority 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or ``priority`` is a
                sequence whose length differs from the number of prompts.
                All checks run before any request is handed to the engine.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use ``Sequence`` (not just ``list``) so these checks agree with the
        # per-prompt indexing below; tuples were previously unvalidated.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Validate up front so a short list cannot fail part-way through
        # submission, after some requests were already added to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1370

1371
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Hand a single prompt to the engine under a new request ID.

        Args:
            prompt: The prompt to process.
            params: Sampling or pooling parameters for this prompt.
            lora_request: Optional LoRA adapter for this request.
            prompt_adapter_request: Optional prompt adapter for this request.
            priority: Scheduling priority forwarded to the engine.
        """
        # The ID is the string form of the next value drawn from
        # ``self.request_counter``. ``_run_engine`` sorts results with
        # ``int(request_id)``, so the IDs must stay integer strings.
        request_number = next(self.request_counter)
        engine = self.llm_engine
        engine.add_request(
            str(request_number),
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1388

1389
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Merge deprecated guided-decoding options into ``params``.

        Args:
            params: Sampling parameters, mutated in place when
                ``guided_options`` is provided.
            guided_options: Legacy guided-decoding request. When ``None``,
                ``params`` is left untouched.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``params.guided_decoding`` is already set while
                ``guided_options`` is also provided.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")
            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )
        return params

1412
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to show a progress bar. For ``RequestOutput``
                results, the bar's postfix also shows estimated input and
                output token throughput.

        Returns:
            Every finished output, sorted by ``int(request_id)``.
        """
        # Initialize tqdm.
        pbar: Optional[tqdm] = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # format_dict is rebuilt on every access, so read it
                        # once. Guard against a zero elapsed time so the
                        # speed estimate cannot divide by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            # Close the bar even if an engine step raises, so the terminal
            # is not left with a dangling progress bar.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))