llm.py 59.9 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
from vllm.config import CompilationConfig
17
18
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
19
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
20
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
21
                                         ChatTemplateContentFormatOption,
22
23
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
24
25
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
26
27
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
28
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
29
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
30
from vllm.logger import init_logger
31
from vllm.lora.request import LoRARequest
32
33
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
34
35
36
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
37
from vllm.pooling_params import PoolingParams
38
from vllm.prompt_adapter.request import PromptAdapterRequest
39
40
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
41
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
42
43
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
44
from vllm.usage.usage_lib import UsageContext
45
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
46

47
48
logger = init_logger(__name__)

# Generic return type for `LLM.collective_rpc` / `LLM.apply_model`.
# `default=Any` (a PEP 696 TypeVar default, supported via
# typing_extensions.TypeVar) applies when no concrete type is bound.
_R = TypeVar("_R", default=Any)


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`.
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine-args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """

    #: Whether calls to the legacy generate/encode API (e.g. passing
    #: ``prompt_token_ids``) should emit deprecation warnings.
    DEPRECATE_LEGACY: ClassVar[bool] = True

    #: Whether passing positional arguments (other than ``model``) to
    #: :meth:`LLM.__init__` should emit deprecation warnings.
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Enable deprecation warnings for the legacy API inside a block.

        Sets :attr:`DEPRECATE_LEGACY` to ``True`` for the duration of the
        ``with`` block and restores the value it had on entry afterwards,
        including when the block raises.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore the caller's setting instead of hard-coding ``False``:
            # the class default is ``True``, so resetting to ``False`` would
            # silently turn legacy-API deprecation warnings off after the
            # first use of this context manager. ``finally`` guarantees the
            # flag does not leak when the ``with`` body raises.
            cls.DEPRECATE_LEGACY = previous

    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Collects the explicit arguments plus any extra ``kwargs`` into an
        :class:`~vllm.EngineArgs` and builds the underlying engine from it.

        Note: if ``enforce_eager`` is unset (``enforce_eager is None``)
        it defaults to False.
        """
        # Stats logging is off unless the caller explicitly asks for it.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class given as a type (instead of a qualified string
        # name) is serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # Integers and dicts go through the CLI parser; any other value
        # (None, or presumably a ready-made CompilationConfig) is forwarded
        # unchanged.
        if isinstance(compilation_config, (int, dict)):
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            # Model and tokenizer selection.
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            hf_overrides=hf_overrides,
            allowed_local_media_path=allowed_local_media_path,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            # Parallelism, precision and memory.
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            # Execution mode.
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled in lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer that does not look already cached is wrapped with
        :func:`get_cached_tokenizer` before being installed.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # Cached tokenizer classes are generated dynamically, so the class
        # name is the only available signal. A user-defined tokenizer whose
        # class name begins with "Cached" will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))

    def get_default_sampling_params(self) -> SamplingParams:
        """Return sampling params seeded with the model config's overrides.

        The overrides come from ``model_config.get_diff_sampling_param()`` and
        are stored on the instance so later calls reuse them. When there are
        no overrides, a plain ``SamplingParams()`` is returned.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

    # Recommended form: prompts are passed positionally (text, token or
    # structured prompts); everything after `sampling_params` is keyword-only.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        # Accepted by the implementation; declared here so type checkers
        # allow `generate(..., priority=[...])` on the recommended form.
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of supplying token IDs alongside (or
                instead of) ``prompts``. Deprecated; see the note below.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. May be a
                :class:`GuidedDecodingRequest` or a dict holding at most one
                guided decoding option, which is converted into a
                :class:`GuidedDecodingRequest`.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the "generate" or
                "transcription" runner, or if ``guided_options_request`` is a
                dict specifying more than one guided decoding option.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass token IDs via the ``prompts`` parameter, e.g. as a
            :class:`~vllm.inputs.TokensPrompt`.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        # Legacy path: merge `prompts` + `prompt_token_ids` into prompt
        # objects; otherwise `prompts` is already in the new format.
        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        # A dict form may carry only a single guided decoding option.
        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Invoke ``method`` on every worker and gather what each returns.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and sent to every worker for execution.

                A callable must accept an additional ``self`` argument (the
                worker object) besides whatever is passed via ``args`` and
                ``kwargs``.
            timeout: Maximum time in seconds to wait for execution. Raises a
                :exc:`TimeoutError` on timeout. `None` means wait indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list holding one result per worker.

        Note:
            Prefer using this API for control messages only; set up separate
            data-plane communication to move data.
        """
        return self.llm_engine.model_executor.collective_rpc(
            method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call ``func`` on the model held by each worker.

        Returns:
            The per-worker results, as returned by the model executor's
            ``apply_model``.
        """
        return self.llm_engine.model_executor.apply_model(func)

    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is either a token prompt
                (its ``prompt_token_ids`` are used as-is) or a text prompt
                (its ``prompt`` is encoded with this LLM's tokenizer).
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per prompt, in input order, each
            holding up to ``params.beam_width`` highest-scoring sequences
            (decoded text is attached to every returned beam).

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Bound before the scoring closure below, which reads its EOS id.
        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten the live beams of all instances into one batch.
            # chain.from_iterable is linear in the number of beams, whereas
            # sum(..., []) re-copies the accumulated list on every addition.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            # all_beams[pos[k]:pos[k + 1]] are the beams of instances[k].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            if not all_beams:
                break

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    # NOTE(review): a beam whose logprobs are None is neither
                    # kept nor moved to `completed`, so it is dropped from
                    # the final results — confirm this is intended.
                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after max_tokens compete with finished ones.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Each conversation is rendered into a prompt with the tokenizer's chat
        template (or the Mistral template for a ``MistralTokenizer``) and
        the resulting prompts are passed to :meth:`generate`.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation, or a list of conversations.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

            add_generation_prompt: If True, the chat template appends the
                prompt that starts a new assistant turn at the end of each
                conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional tool definitions, forwarded unchanged to the
                chat template renderer.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Attached to every rendered prompt. Only used
                for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        # A list of lists is a batch of conversations; anything else is a
        # single conversation that we wrap into a batch of one.
        conversations: list[list[ChatCompletionMessageParam]]
        if is_list_of(messages, list):
            conversations = cast(list[list[ChatCompletionMessageParam]],
                                 messages)
        else:
            conversations = [cast(list[ChatCompletionMessageParam], messages)]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )
        uses_mistral_template = isinstance(tokenizer, MistralTokenizer)

        engine_prompts: list[Union[TokensPrompt, TextPrompt]] = []
        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            # The renderer may return either text or token ids.
            rendered: Union[str, list[int]]
            if uses_mistral_template:
                rendered = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            engine_prompt: Union[TokensPrompt, TextPrompt] = (
                TokensPrompt(prompt_token_ids=rendered)
                if is_list_of(rendered, int) else TextPrompt(prompt=rendered))

            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs

            engine_prompts.append(engine_prompt)

        return self.generate(
            engine_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

750
751
752
753
754
755
756
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Preferred form: a single prompt or a sequence of prompts."""

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: one text prompt, optionally with its token ids."""

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: a batch of text prompts, optionally with token ids."""

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: one token-id prompt passed by keyword."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: a batch of token-id prompts passed by keyword."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy form: token ids passed positionally after two ``None``s."""

nunjunj's avatar
nunjunj committed
835
836
837
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            prompt_token_ids: LEGACY. Token ids of one prompt or of a batch of
                prompts. When given, ``prompts`` (if any) must be text with a
                matching batch size.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the pooling
                runner.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and is flagged
            by the ``deprecate_kwargs`` decorator when legacy deprecation is
            enabled. You should instead pass token ids via the ``prompts``
            parameter, e.g. as a :class:`~vllm.inputs.TokensPrompt`.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                # Point users at the flag they most likely forgot.
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call form: build prompts from text and/or token ids.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
916

917
918
919
920
921
922
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        Thin wrapper over :meth:`encode` with default pooling parameters;
        each pooled output is converted to an ``EmbeddingRequestOutput``.
        Prompts are batched automatically, considering the memory
        constraint, so pass all prompts in one call for best performance.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not started with ``--task embed``.
        """
        configured_task = self.llm_engine.model_config.task
        if configured_task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled_outputs = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled_outputs))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model was not started with ``--task classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        # Classification reuses the generic pooling path with default params.
        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

997
998
999
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by comparing their pooled embeddings.

        All texts of ``text_1`` and ``text_2`` are encoded in one batch. If
        ``text_1`` has a single entry its embedding is paired with every
        entry of ``text_2``; otherwise entries are paired index by index.
        The pairs are scored with ``_cosine_similarity``.

        Args:
            tokenizer: Tokenizer of the model; used for truncation and by
                ``_cosine_similarity``.
            text_1: Left-hand inputs (one, or one per ``text_2`` entry).
            text_2: Right-hand inputs.
            truncate_prompt_tokens: If set, plain-string inputs are tokenized
                here with ``truncation=True, max_length=...`` so that each
                input keeps at most this many tokens. Prompt dicts are passed
                through unchanged.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair.
        """
        inputs: list[Union[str, TextPrompt, TokensPrompt]] = text_1 + text_2
        if truncate_prompt_tokens is not None:
            # Apply truncation the same way `_cross_encoding_score` does: let
            # the tokenizer cut each text to `max_length` tokens. The text is
            # passed positionally so both HF-style and Mistral tokenizers work.
            inputs = [
                TokensPrompt(prompt_token_ids=tokenizer(
                    item, truncation=True, max_length=truncate_prompt_tokens)
                             ["input_ids"]) if isinstance(item, str) else item
                for item in inputs
            ]

        encoded_output: list[PoolingRequestOutput] = self.encode(
            inputs,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        num_left = len(text_1)
        encoded_output_1: list[PoolingRequestOutput] = encoded_output[:num_left]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[num_left:]

        # 1 -> N: reuse the single left-hand embedding for every right text.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score ``(text_1[i], text_2[i])`` pairs with a cross-encoder.

        Each pair is tokenized jointly (``text`` / ``text_pair``) and the
        resulting token ids (plus ``token_type_ids`` when the tokenizer
        returns them) are run through the pooling model.

        Args:
            tokenizer: Tokenizer used to encode each pair jointly.
            text_1: Left-hand texts. A single entry is paired with every
                entry of ``text_2``.
            text_2: Right-hand texts.
            truncate_prompt_tokens: If set, each joint encoding is truncated
                by the tokenizer to at most this many tokens.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One ``ScoringRequestOutput`` per pair.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``.
        """
        if isinstance(tokenizer, MistralTokenizer):
            # The joint encoding below uses the HF `text=`/`text_pair=` call
            # signature, which the Mistral tokenizer is not used with here.
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        # 1 -> N: replicate the single left-hand text for every right text.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []
        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1083
1084
1085
1086
1087
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        Cross-encoder models score each pair jointly; other models (``--task
        embed``) score pairs from separately pooled embeddings. This method
        automatically batches the prompts, considering the memory constraint.
        For the best performance, put all of your texts into a single list and
        pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional maximum number of tokens per
                tokenized input, forwarded to the scoring backend.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, the task is not
                ``embed``/``score``, a prompt carries multi-modal data, or
                the input lengths are incompatible.
            TypeError: If a prompt cannot be reduced to a text string.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            """Reduce a singleton prompt to plain text for scoring."""
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check rather than `assert`, so invalid user input is
            # still rejected when Python runs with optimizations (-O).
            if not isinstance(prompt, str):
                raise TypeError(
                    "Scoring prompts must be strings, TextPrompt or "
                    f"TokensPrompt dicts; got {type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1183

1184
1185
1186
1187
1188
1189
    def start_profile(self) -> None:
        """Start profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

1190
1191
1192
    def reset_prefix_cache(self) -> bool:
        """Reset the engine's prefix cache.

        Returns:
            The value reported by ``llm_engine.reset_prefix_cache()``
            (presumably whether the reset succeeded).
        """
        outcome = self.llm_engine.reset_prefix_cache()
        return outcome

1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.

        The caller should guarantee that no requests are being processed
        during the sleep period, before :meth:`wake_up` is called.

        Args:
            level: The sleep level.

              - Level 1 offloads the model weights to CPU memory and discards
                the KV cache (its content is forgotten). Use it to sleep and
                wake up the engine for the same model again; make sure there
                is enough CPU memory to hold the weights.
              - Level 2 discards both the model weights and the KV cache.
                Use it when waking up with a different or updated model,
                where the previous weights are not needed. It reduces CPU
                memory pressure.
        """
        # The KV cache is discarded while sleeping, so clear the prefix cache
        # before handing control to the engine.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

    def wake_up(self):
        """
        Wake the engine up from sleep mode.

        See :meth:`sleep` for what each sleep level discards.
        """
        engine = self.llm_engine
        engine.wake_up()

1220
1221
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Turn legacy ``prompts`` / ``prompt_token_ids`` arguments into a
        list of prompt dicts.

        Each argument may describe a single request or a batch. When both
        are given their batch sizes must match, and the text prompts are the
        ones that get used (the token ids only take part in the length
        check).

        Returns:
            A list with one ``TextPrompt`` (when ``prompts`` is set) or one
            ``TokensPrompt`` per request.

        Raises:
            ValueError: If the batch sizes differ or neither argument is set.
        """
        # skip_tokenizer_init is now checked in engine
        texts = (None if prompts is None else
                 [p["content"] for p in parse_and_batch_prompt(prompts)])
        token_batches = (None if prompt_token_ids is None else [
            p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
        ])

        both_given = texts is not None and token_batches is not None
        if both_given and len(texts) != len(token_batches):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
        if texts is None and token_batches is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        parsed_prompts: list[PromptType]
        if texts is not None:
            parsed_prompts = [TextPrompt(prompt=text) for text in texts]
        else:
            assert token_batches is not None  # both-None case raised above
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=ids) for ids in token_batches
            ]
        return parsed_prompts
1263
1264
1265

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and submit every prompt to the
        engine.

        ``params`` and ``lora_request`` may each be a single value shared by
        all prompts or a sequence with one entry per prompt. Any
        ``collections.abc.Sequence`` (list, tuple, ...) counts as per-prompt,
        both for the length check and for indexing.

        Args:
            prompts: One prompt or a sequence of prompts.
            params: Sampling or pooling parameters, shared or per prompt.
                Every ``SamplingParams`` is updated in place to emit only the
                final output.
            lora_request: LoRA request, shared or per prompt.
            prompt_adapter_request: Prompt adapter request shared by all
                prompts.
            guided_options: Deprecated guided-decoding options applied to
                every ``SamplingParams``.
            priority: Optional per-prompt priorities; 0 when not given.

        Raises:
            ValueError: If a per-prompt ``params`` or ``lora_request``
                sequence does not have one entry per prompt, or if
                ``_add_guided_params`` rejects the guided options.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        # Use the same `Sequence` test as the indexing below, so tuples are
        # validated and configured exactly like lists.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1312

1313
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Forward one prompt to the engine under a freshly drawn request id.

        The id is the next value of ``self.request_counter``, stringified.
        All other arguments are passed through unchanged to
        ``self.llm_engine.add_request``.
        """
        new_request_id = str(next(self.request_counter))
        engine = self.llm_engine
        engine.add_request(
            new_request_id,
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy legacy ``guided_options`` into ``params.guided_decoding``.

        Args:
            params: Sampling parameters to update. They are mutated in place.
            guided_options: Deprecated guided-decoding request. When ``None``,
                ``params`` is left untouched.

        Returns:
            The same ``params`` object that was passed in.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")
            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
            )
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If True, show a progress bar sized to the number of
                unfinished requests at entry. For ``RequestOutput`` results,
                the bar also shows estimated input/output token throughput.

        Returns:
            Every finished output, sorted by integer request id.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        elapsed = pbar.format_dict["elapsed"]
                        # tqdm can report an elapsed time of 0 (e.g. when
                        # the bar is disabled). Skip the speed estimate
                        # rather than divide by zero.
                        if elapsed:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            # Close the bar even if a step raises, so no dangling progress
            # line is left behind.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))