llm.py 60.4 KB
Newer Older
1
import itertools
2
import warnings
3
from contextlib import contextmanager
4
5
from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
                    Tuple, Type, Union, cast, overload)
6

7
import cloudpickle
8
import torch
9
import torch.nn as nn
10
from tqdm import tqdm
11
from typing_extensions import TypeVar, deprecated
12

13
from vllm import envs
14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
from vllm.config import CompilationConfig
17
18
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
19
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
20
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
21
                                         ChatTemplateContentFormatOption,
22
23
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
24
25
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
26
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
27
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
28
from vllm.logger import init_logger
29
from vllm.lora.request import LoRARequest
30
31
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
32
33
34
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
35
from vllm.pooling_params import PoolingParams
36
from vllm.prompt_adapter.request import PromptAdapterRequest
37
38
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
39
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
40
41
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
42
from vllm.usage.usage_lib import UsageContext
43
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
44

45
46
logger = init_logger(__name__)

# Generic result type for callables dispatched to workers through
# ``LLM.collective_rpc`` / ``LLM.apply_model``; defaults to ``Any``.
_R = TypeVar("_R", default=Any)

49
50

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_seq_len_to_capture: Maximum sequence length covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`.
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine-args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """
129

130
    #: A flag to toggle whether to deprecate the legacy generate/encode API.
    DEPRECATE_LEGACY: ClassVar[bool] = True

    #: A flag to toggle whether to deprecate positional arguments in
    #: :meth:`LLM.__init__`.
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True

139
140
141
142
143
144
145
146
147
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Enable deprecation of the legacy API for the duration of a block.

        While the ``with`` block runs, ``cls.DEPRECATE_LEGACY`` is ``True``.
        That flag is what the ``is_deprecated`` callback of the
        ``deprecate_kwargs`` decorator on :meth:`generate` reads.

        On exit, the flag is restored to the value it had on entry. This
        happens even if the body raises.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore rather than hard-code ``False``. The class default is
            # ``True``, so forcing ``False`` would silently turn deprecation
            # warnings off for every later caller. ``finally`` guarantees the
            # restore also happens when the body raises.
            cls.DEPRECATE_LEGACY = previous

148
149
150
151
152
153
154
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """Construct the engine used for offline inference.

        Every argument ends up in :class:`~vllm.EngineArgs`. A few are first
        normalized:

        - ``disable_log_stats`` defaults to ``True`` unless given in kwargs.
        - A ``worker_cls`` passed as a class object is serialized with
          ``cloudpickle``.
        - An ``int`` or ``dict`` ``compilation_config`` is parsed with
          ``CompilationConfig.from_cli``. Any other non-None value is
          passed through unchanged.

        Note: if enforce_eager is unset (enforce_eager is None) it defaults
        to False.
        """
        # Stats logging is off by default unless the caller opted in.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class given as a class object (rather than a qualified
        # string name) is serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if compilation_config is None:
            compilation_config_instance = None
        elif isinstance(compilation_config, (int, dict)):
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # The engine implementation is chosen at runtime (not import time)
        # to avoid import-order issues.
        engine_cls = self.get_engine_class()
        self.engine_class = engine_cls
        self.llm_engine = engine_cls.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

        self.request_counter = Counter()

Joe Runde's avatar
Joe Runde committed
245
246
247
248
249
250
251
252
    @staticmethod
    def get_engine_class() -> Type[LLMEngine]:
        """Return the engine class to instantiate.

        Returns the V1 engine when ``envs.VLLM_USE_V1`` is set. Otherwise
        returns the default :class:`LLMEngine`.
        """
        if not envs.VLLM_USE_V1:
            return LLMEngine
        # Lazy import: the v1 package isn't distributed
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        return V1LLMEngine  # type: ignore

253
254
255
256
257
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer stored on the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Store ``tokenizer`` on the engine's tokenizer group.

        A tokenizer whose class name does not start with ``"Cached"`` is
        wrapped with ``get_cached_tokenizer`` before being stored.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # Cached tokenizer classes are created dynamically, so the class-name
        # prefix is the only available signal. A user-defined tokenizer whose
        # class name starts with 'Cached' will be misjudged as already cached.
        already_cached = tokenizer.__class__.__name__.startswith("Cached")
        tokenizer_group.tokenizer = (tokenizer if already_cached else
                                     get_cached_tokenizer(tokenizer))
266

267
268
269
270
271
272
273
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the sampling parameters used when the caller supplies none.

        If the model config reports overrides through
        ``get_diff_sampling_param()``, they are applied with
        ``SamplingParams.from_optional``. Otherwise a default
        ``SamplingParams()`` is returned.
        """
        overrides = self.llm_engine.model_config.get_diff_sampling_param()
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

274
275
276
277
278
279
280
    # Type-checker-facing signatures of ``generate``. The runtime
    # implementation (wrapped by ``deprecate_kwargs``) follows them.

    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        # Mirrors the implementation's ``priority`` parameter. Without it,
        # type checkers reject ``generate(..., priority=[...])`` even though
        # the call works at runtime.
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
371
372
373
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters
                (see :meth:`get_default_sampling_params`).
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy, deprecated way to pass token IDs. When
                given, it is combined with ``prompts`` (text) into the new
                prompt format before processing.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                may specify at most one guided decoding option; it is
                converted into a ``GuidedDecodingRequest``.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the 'generate'
                runner, or if ``guided_options_request`` is a dict holding
                more than one option.

        Note:
            Passing ``prompt_token_ids`` is legacy and deprecated. Pass token
            IDs through the ``prompts`` parameter instead, e.g. as a
            :class:`~vllm.inputs.TokensPrompt`.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = (
                self.llm_engine.model_config.supported_runner_types)
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge text prompts and token IDs into the new
            # prompt format.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
469

470
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
        """Run ``method`` on every worker and collect what each returns.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and sent to every worker. A callable must accept
                an additional ``self`` argument (the worker object) besides
                ``args`` and ``kwargs``.
            timeout: Seconds to wait before raising :exc:`TimeoutError`.
                ``None`` waits indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list with one result per worker.

        Note:
            Use this API for control messages only. Set up separate
            data-plane communication to move data.
        """
        return self.llm_engine.model_executor.collective_rpc(
            method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Call ``func`` on the model held by each worker.

        Args:
            func: Callable that receives a worker's model.

        Returns:
            The per-worker results reported by the model executor.
        """
        return self.llm_engine.model_executor.apply_model(func)
507

508
509
    def beam_search(
        self,
        prompts: List[Union[str, TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a plain string,
                a :class:`~vllm.inputs.TextPrompt` (tokenized here), or a
                :class:`~vllm.inputs.TokensPrompt` carrying token IDs.
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per prompt, in input order. Each
            holds up to ``params.beam_width`` sequences, sorted by
            ``get_beam_search_score`` from best to worst, with ``text``
            decoded from the sequence's tokens.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Defined before the closure below that captures it.
        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            if isinstance(prompt, str):
                # Plain strings are accepted (as documented); indexing them
                # like a TextPrompt dict would raise TypeError.
                prompt_tokens = tokenizer.encode(prompt)
            elif is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten all live beams in linear time. The previous
            # ``sum(lists, [])`` was quadratic in the number of instances.
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # [start, end) slice of ``all_beams`` owned by each instance.
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if (token_id == tokenizer.eos_token_id
                                    and not ignore_eos):
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Unfinished beams still compete with completed ones.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
618
619
    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered into a prompt (text or token ids) with
        the chat template. The Mistral renderer is used when the tokenizer is
        a ``MistralTokenizer``. The resulting prompts are passed to
        :meth:`generate`.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``
              - "auto" (the default) leaves the choice to
                ``resolve_chat_template_content_format``, which is given the
                chat template and the tokenizer.

            add_generation_prompt: If True, adds a generation template
                (e.g. the assistant turn header) to the end of each rendered
                conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional list of tool definitions (dicts). They are
                forwarded unchanged to the chat-template renderer.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is List[...]
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral renderer consumes the raw messages rather than
                # the parsed conversation.
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            # The renderers may return either token ids or text.
            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

753
754
755
756
757
758
759
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        # Current API: `prompts` is positional-only and the remaining
        # options after `pooling_params` are keyword-only.
        ...

767
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        # Legacy form: one text prompt plus optional token ids.
        ...
780

781
    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        # Legacy form: a batch of text prompts plus optional token-id lists.
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        # Legacy form: one token-id list (keyword) plus an optional prompt.
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        # Legacy form: a batch of token-id lists (keyword) plus optional
        # prompts.
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        # Legacy form: token ids passed positionally after two `None`s.
        ...

nunjunj's avatar
nunjunj committed
838
839
840
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            prompt_token_ids: Deprecated legacy alternative to ``prompts``.
                When given, it is combined with ``prompts`` by
                ``_convert_v1_inputs``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the pooling runner,
                or if neither ``prompts`` nor ``prompt_token_ids`` is given.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass the prompts positionally as the first argument.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        elif prompts is None:
            # Fail fast with a clear message instead of an opaque TypeError
            # from `len(None)` while the requests are being added.
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
919

920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        Prompts are batched automatically with the memory constraint taken
        into account. Passing every prompt in one call gives the best
        throughput.

        Args:
            prompts: The prompts to the LLM. A sequence of prompts may be
                passed for batch inference; see
                :class:`~vllm.inputs.PromptType` for the accepted formats.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects, one per prompt and
            in the same order as the input prompts.

        Raises:
            ValueError: If the engine was not configured with
                ``--task embed``.
        """
        configured_task = self.llm_engine.model_config.task
        if configured_task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled_outputs = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled_outputs))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ClassificationRequestOutput]:
        """
        Generate classification outputs for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification result for each prompt, in the same order as the
            input prompts.

        Raises:
            ValueError: If the engine was not configured with
                ``--task classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: List[Union[str, TextPrompt, TokensPrompt]],
        text_2: List[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their pooled embeddings.

        ``text_1`` and ``text_2`` are encoded together in one :meth:`encode`
        call. If ``text_1`` holds a single item, its embedding is paired with
        every item of ``text_2``. Otherwise items are paired positionally.

        Args:
            tokenizer: Used to truncate the inputs when requested. Its
                ``pad_token_id``, if present, separates the two token-id
                lists reported for each pair.
            text_1: Left-hand inputs. :meth:`score` passes plain strings.
            text_2: Right-hand inputs. :meth:`score` passes plain strings.
            truncate_prompt_tokens: If not None, each input is tokenized with
                ``truncation=True`` and ``max_length`` set to this value
                before encoding. These are the same tokenizer kwargs the
                cross-encoder path uses.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair.
        """
        texts = text_1 + text_2
        encode_inputs: List[Union[str, TextPrompt, TokensPrompt]]
        if truncate_prompt_tokens is None:
            encode_inputs = texts
        else:
            # Previously this argument was accepted but never applied.
            # Tokenize up front so the engine receives truncated token ids.
            # Assumes plain-string inputs, which is what score() passes.
            encode_inputs = [
                TokensPrompt(prompt_token_ids=tokenizer(
                    text=text,
                    truncation=True,
                    max_length=truncate_prompt_tokens,
                )["input_ids"]) for text in texts
            ]

        encoded_output = self.encode(
            encode_inputs,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)
        encoded_output_1 = encoded_output[0:len(text_1)]
        encoded_output_2 = encoded_output[len(text_1):]

        if len(encoded_output_1) == 1:
            # 1 -> N: reuse the single left-hand embedding for every pair.
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = []
        scorer = torch.nn.CosineSimilarity(0)
        pad_token_id = getattr(tokenizer, "pad_token_id", None)

        for embed_1, embed_2 in zip(encoded_output_1, encoded_output_2):
            pair_score = scorer(embed_1.outputs.data, embed_2.outputs.data)

            if pad_token_id is not None:
                tokens = (embed_1.prompt_token_ids + [pad_token_id] +
                          embed_2.prompt_token_ids)
            else:
                tokens = embed_1.prompt_token_ids + embed_2.prompt_token_ids

            scores.append(
                PoolingRequestOutput(
                    request_id=f"{embed_1.request_id}_{embed_2.request_id}",
                    outputs=pair_score,
                    prompt_token_ids=tokens,
                    finished=True))

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: List[Union[str, TextPrompt, TokensPrompt]],
        text_2: List[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ScoringRequestOutput]:
        """Score pairs by tokenizing each ``(text_1, text_2)`` pair jointly
        and running it through the pooling model.

        If ``text_1`` holds a single item, it is paired with every item of
        ``text_2``. Otherwise items are paired positionally.

        Args:
            tokenizer: Called as ``tokenizer(text=..., text_pair=...)`` to
                build each pair's token ids (and token type ids, if the
                tokenizer returns them).
            text_1: Left-hand texts. :meth:`score` passes plain strings.
            text_2: Right-hand texts. :meth:`score` passes plain strings.
            truncate_prompt_tokens: If not None, pairs are tokenized with
                ``truncation=True`` and ``max_length`` set to this value.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``, which this
                path rejects.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The old message blamed the `--task` flag, which is not the
            # problem here.
            raise ValueError(
                "Cross-encoder scoring is not supported with "
                "MistralTokenizer")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        pooling_params = PoolingParams()

        tokenization_kwargs: Dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []
        for q, t in zip(text_1, text_2):
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1101
1102
1103
1104
1105
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        Cross-encoder models score each pair jointly. Other models go through
        the embedding-based scoring path. Prompts are batched automatically,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional token limit forwarded to the
                scoring helper. The cross-encoder path tokenizes each pair
                with ``truncation=True`` and ``max_length`` set to this value.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model configured with
                ``--task embed`` or ``--task score``, if a prompt is
                multi-modal or cannot be turned into text, or if the input
                lengths are empty or incompatible.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            """Reduce a singleton prompt to plain text (token ids are
            decoded with the engine tokenizer)."""
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # A real check instead of `assert`, which `python -O` strips.
            if not isinstance(prompt, str):
                raise ValueError("Unsupported prompt type for scoring: "
                                 f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        text_1 = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        text_2 = [ensure_str(t) for t in text_2]

        if len(text_1) > 1 and len(text_1) != len(text_2):
            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
        if len(text_1) == 0:
            raise ValueError("At least one text element must be given")
        if len(text_2) == 0:
            raise ValueError("At least one text_pair element must be given")

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, text_1, text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(tokenizer, text_1, text_2,
                                         truncate_prompt_tokens, use_tqdm,
                                         lora_request, prompt_adapter_request)
1200

1201
1202
1203
1204
1205
1206
    def start_profile(self) -> None:
        """Delegate to ``self.llm_engine.start_profile()``; returns None."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Delegate to ``self.llm_engine.stop_profile()``; returns None."""
        engine = self.llm_engine
        engine.stop_profile()

1207
1208
1209
    def reset_prefix_cache(self) -> bool:
        """Delegate to ``self.llm_engine.reset_prefix_cache()`` and return
        the engine's result unchanged."""
        engine = self.llm_engine
        outcome = engine.reset_prefix_cache()
        return outcome

1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
    def sleep(self, level: int = 1) -> None:
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before :meth:`wake_up` is called.

        The prefix cache is reset first, and then ``level`` is forwarded to
        the engine's ``sleep``.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache is
                forgotten. Level 1 sleep is good for sleeping and waking up
                the engine to run the same model again. The model weights are
                backed up in CPU memory. Please make sure there's enough CPU
                memory to store the model weights. Level 2 sleep will discard
                both the model weights and the kv cache. The content of both
                the model weights and kv cache is forgotten. Level 2 sleep is
                good for sleeping and waking up the engine to run a different
                model or update the model, where previous model weights are
                not needed. It reduces CPU memory pressure.
        """
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

    def wake_up(self) -> None:
        """
        Wake up the engine from sleep mode. See the :meth:`sleep` method
        for more details.

        This simply delegates to ``self.llm_engine.wake_up()``.
        """
        self.llm_engine.wake_up()

1237
1238
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ) -> List[PromptType]:
        """Convert the legacy ``prompts`` / ``prompt_token_ids`` arguments
        into a list of prompt items, one per request.

        Args:
            prompts: A single prompt string, a list of them, or None.
            prompt_token_ids: A single token-id list, a list of them, or
                None.

        Returns:
            A ``TextPrompt`` per request when ``prompts`` is given, otherwise
            a ``TokensPrompt`` per request.

        Raises:
            ValueError: If both arguments are given with different lengths,
                or if neither is given.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        num_requests = None
        if prompts is not None:
            num_requests = len(prompts)
        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

            num_requests = len(prompt_token_ids)
        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        parsed_prompts: List[PromptType] = []
        for i in range(num_requests):
            item: PromptType

            # NOTE(review): when both `prompts` and `prompt_token_ids` are
            # given, the text prompt wins and the token ids are dropped;
            # confirm this precedence is intended for the legacy API.
            if prompts is not None:
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
            else:
                # Unreachable: num_requests is only set when one is given.
                raise AssertionError

            parsed_prompts.append(item)

        return parsed_prompts
1280
1281
1282

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate per-request arguments and add one engine request per
        prompt.

        Every ``SamplingParams`` passed in is mutated in place. It is handed
        to ``_add_guided_params`` and its ``output_kind`` is set to
        ``RequestOutputKind.FINAL_ONLY``.

        Args:
            prompts: A single prompt (``str`` or prompt ``dict``) or a
                sequence of prompts.
            params: One ``SamplingParams``/``PoolingParams`` shared by all
                prompts, or a sequence with exactly one entry per prompt.
            lora_request: One LoRA request shared by all prompts, or a
                sequence with exactly one entry per prompt.
            prompt_adapter_request: Prompt adapter request applied to every
                prompt.
            guided_options: Deprecated. Forwarded to ``_add_guided_params``
                for each ``SamplingParams``; emits a DeprecationWarning.
            priority: Optional per-prompt priorities. When None or empty,
                every request gets priority 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or a non-empty
                ``priority`` is a sequence whose length differs from the
                number of prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use one `Sequence` test for validation, the in-place param update
        # and per-request dispatch, so tuples behave exactly like lists.
        params_is_seq = isinstance(params, Sequence)
        lora_is_seq = isinstance(lora_request, Sequence)

        if params_is_seq and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if lora_is_seq and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if params_is_seq else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if params_is_seq else params,
                lora_request=lora_request[i] if lora_is_seq else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1329

1330
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the underlying engine.

        The request ID is the next value drawn from ``self.request_counter``,
        converted to a string.

        Args:
            prompt: The prompt to submit.
            params: Sampling or pooling parameters for this prompt.
            lora_request: Optional LoRA request for this prompt.
            prompt_adapter_request: Optional prompt adapter request.
            priority: Scheduling priority forwarded to the engine.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            prompt,
            params,
            priority=priority,
            prompt_adapter_request=prompt_adapter_request,
            lora_request=lora_request,
        )

    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None
    ) -> SamplingParams:
        """Merge deprecated ``guided_options`` into ``params.guided_decoding``.

        ``params`` is modified in place and is also returned.

        Args:
            params: Sampling parameters to update.
            guided_options: Deprecated guided-decoding request. When ``None``,
                ``params`` is returned unchanged.

        Returns:
            The same ``params`` object that was passed in.

        Raises:
            ValueError: If ``params.guided_decoding`` is already set. The two
                sources of guided-decoding configuration would conflict.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        # Field-by-field translation from the legacy request object to the
        # GuidedDecodingParams structure.
        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Only outputs with ``finished`` set are collected.

        Args:
            use_tqdm: Whether to show a progress bar. The bar advances once per
                finished output. For ``RequestOutput`` results it also shows
                the estimated input and output token throughput.

        Returns:
            The finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        elapsed = pbar.format_dict["elapsed"]
                        # tqdm can report an elapsed time of 0 (e.g. for a
                        # bar disabled via TQDM_DISABLE). In that case skip
                        # the rate update instead of dividing by zero.
                        if elapsed:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even if an engine step raises.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))