llm.py 64.3 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
import itertools
4
import warnings
5
from collections.abc import Sequence
6
from contextlib import contextmanager
7
from typing import Any, Callable, ClassVar, Optional, Union, cast, overload
8

9
import cloudpickle
10
import torch.nn as nn
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
16
17
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
18
19
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
20
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
21
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
22
                                         ChatTemplateContentFormatOption,
23
24
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
25
26
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
27
28
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
29
from vllm.entrypoints.utils import _validate_truncation_size
30
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
31
from vllm.inputs.parse import parse_and_batch_prompt
32
from vllm.logger import init_logger
33
from vllm.lora.request import LoRARequest
34
35
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
36
from vllm.model_executor.layers.quantization import QuantizationMethods
37
38
39
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
40
from vllm.pooling_params import PoolingParams
41
from vllm.prompt_adapter.request import PromptAdapterRequest
42
43
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
44
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
45
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
46
from vllm.usage.usage_lib import UsageContext
47
48
from vllm.utils import (Counter, Device, deprecate_args, deprecate_kwargs,
                        is_list_of)
49

50
51
# Result type for `collective_rpc` / `apply_model`; falls back to `Any`
# when a call site does not pin it down.
_R = TypeVar("_R", default=Any)

# Logger for this module.
logger = init_logger(__name__)


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights, e.g.
            "awq", "gptq", or "fp8"; see `QuantizationMethods` for the
            accepted values. If None, we first check the
            `quantization_config` attribute in the model config file. If that
            is None, we assume the model weights are not quantized and use
            `dtype` to determine the data type of the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Note that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Extra keyword arguments for the multi-modal
            processor; forwarded to [EngineArgs][vllm.EngineArgs].
        task: The task to use the model for; forwarded to
            [EngineArgs][vllm.EngineArgs].
        override_pooler_config: Overrides for the pooler configuration;
            forwarded to [EngineArgs][vllm.EngineArgs].
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration;
            keys that are not init fields of `CompilationConfig` are ignored.
            A `CompilationConfig` instance is passed through unchanged.
        **kwargs: Arguments for [EngineArgs][vllm.EngineArgs]. (See
            [engine-args][])

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = True
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    [LLM.__init__][].
    """

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

158
159
160
161
162
163
164
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Collects the explicit parameters plus `**kwargs` into an `EngineArgs`
        and builds the engine with `LLMEngine.from_engine_args`.
        """
        # Stats logging is off unless the caller passes `disable_log_stats`.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class given as a class object (not a qualified name
        # string) is serialized with cloudpickle to avoid pickling issues.
        custom_worker = kwargs.get("worker_cls")
        if isinstance(custom_worker, type):
            kwargs["worker_cls"] = cloudpickle.dumps(custom_worker)

        # Normalize `compilation_config`: an int selects the optimization
        # level, a dict is narrowed to CompilationConfig's init fields, and
        # anything else (e.g. an existing instance) is used as-is.
        if compilation_config is None:
            resolved_compilation = None
        elif isinstance(compilation_config, int):
            resolved_compilation = CompilationConfig(level=compilation_config)
        elif isinstance(compilation_config, dict):
            init_only = {
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            }
            resolved_compilation = CompilationConfig(**init_only)
        else:
            resolved_compilation = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=resolved_compilation,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily filled by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for `lora_request`.

        Looks the tokenizer up through the engine's tokenizer group via
        `get_lora_tokenizer` (presumably the base tokenizer when
        `lora_request` is None — confirm against the tokenizer group).
        """
        group = self.llm_engine.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the engine's tokenizer group.

        A tokenizer that is not already a cached wrapper is wrapped with
        `get_cached_tokenizer` first.
        """
        group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are generated dynamically, so the class
        # name is the only thing to check. A user-defined tokenizer whose
        # class name starts with "Cached" will be misjudged as cached.
        is_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if is_cached else
                           get_cached_tokenizer(tokenizer))

    def get_default_sampling_params(self) -> SamplingParams:
        """Return `SamplingParams` built from the model's default overrides.

        The overrides come from `model_config.get_diff_sampling_param()`. They
        are fetched on first use and cached on
        `self.default_sampling_params`. If the overrides are empty (or None),
        a plain `SamplingParams()` is returned instead.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

    # Primary signature: prompts are positional-only and every option after
    # `sampling_params` is keyword-only. `priority` mirrors the parameter the
    # implementation accepts, so type checkers allow passing it.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy, deprecated. Token IDs for the prompt(s);
                pass token IDs through `prompts` (e.g. as a
                [TokensPrompt][vllm.inputs.TokensPrompt]) instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                is converted to a `GuidedDecodingRequest` and may contain at
                most one entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the "generate" or
                "transcription" runner, or if `guided_options_request` is a
                dict with more than one entry.

        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass the prompts (including token IDs, e.g. as
            `TokensPrompt` objects) positionally via the `prompts` parameter.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call style: `_convert_v1_inputs` (defined elsewhere in
            # this class) combines `prompts` and `prompt_token_ids`.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                `TimeoutError` on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        # Delegates directly to the engine, which dispatches to the workers.
        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Call `func` on the model held by each worker.

        Returns:
            One result per worker, as produced by the engine's
            `model_executor.apply_model`.
        """
        return self.llm_engine.model_executor.apply_model(func)

    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompt dicts. Each is either a `TextPrompt`
                (its "prompt" string is tokenized here) or a `TokensPrompt`
                (its "prompt_token_ids" are used directly). Optional
                "multi_modal_data" and "mm_processor_kwargs" entries are
                carried along with every beam of that prompt.
            params: The beam search parameters.

        Returns:
            One `BeamSearchOutput` per prompt, in input order. Each holds up
            to `params.beam_width` sequences (completed ones plus those still
            running when `max_tokens` was reached), best score first, with
            `text` decoded from the tokens.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.get_tokenizer()
        # Loop-invariant; read once instead of once per candidate token.
        eos_token_id = tokenizer.eos_token_id

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for prompt in prompts:
            # Add multimodal processor kwargs & data
            mm_kwargs: dict[str, Any] = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

        for _ in range(max_tokens):
            # Flatten every instance's live beams into one batch. chain()
            # is linear, unlike sum(list_of_lists, []) which is quadratic.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # all_beams[pos[k]:pos[k + 1]] belongs to instances[k].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                create_tokens_prompt_from_beam(beam) for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered with the tokenizer's chat template into
        a token prompt, and the [generate][] method is then called to
        generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts (one prompt per conversation) and it is paired one
                by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. It is
                also used to select the tokenizer that renders the chat.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.
                The value is resolved by `resolve_chat_template_content_format`
                (which also handles the default `"auto"`).

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are forwarded to the chat
                template and also taken into account when resolving the
                content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. They are applied last, so on a key collision they
                override the named arguments above (e.g. `tools`).
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        # User-supplied kwargs win over the named arguments above.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path renders from the raw messages and yields
                # token ids directly.
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    # Preferred form: `prompts` is positional-only and every other argument
    # is keyword-only. The remaining overloads describe legacy call styles
    # that still pass `prompt_token_ids` separately.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. A single value is verified
                against the model config once; a sequence is verified element
                by element.
            prompt_token_ids: LEGACY. Token IDs passed separately from
                `prompts`; converted via `_convert_v1_inputs`.
            truncate_prompt_tokens: Optional prompt-length limit in tokens.
                It is validated against the model's `max_model_len` and
                forwarded as tokenization kwargs.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the 'pooling'
                runner.

        Note:
            Passing `prompt_token_ids` separately is considered legacy and
            may be deprecated in the future. You should instead pass token
            IDs through the `prompts` parameter (e.g. as a `TokensPrompt`).
        """
        model_config = self.llm_engine.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = model_config.supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # LEGACY path: convert the separate `prompts` / `prompt_token_ids`
            # arguments into prompt objects.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)

    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional prompt-length limit in tokens,
                forwarded to `encode`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with `--task embed`.
        """
        if self.llm_engine.model_config.task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            pooling_params=pooling_params,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        truncate_prompt_tokens: Optional[int] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            truncate_prompt_tokens: Optional prompt-length limit in tokens,
                forwarded to `encode` (same meaning as in `embed`). Defaults
                to None, i.e. the previous behavior.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized with
                `--task classify`.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by cosine similarity of their embeddings.

        Both lists are embedded in a single `encode` call. The output is then
        split back into the `text_1` and `text_2` halves. When `text_1` holds
        exactly one item, its embedding is paired with every `text_2` item
        (the `1 -> N` case).

        Args:
            tokenizer: Tokenizer forwarded to `_cosine_similarity`.
            text_1: Left-hand texts/prompts.
            text_2: Right-hand texts/prompts.
            truncate_prompt_tokens: Forwarded to `encode`.
            use_tqdm: Whether to display a progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One `ScoringRequestOutput` per scored pair.
        """
        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        split = len(text_1)
        encoded_output_1: list[PoolingRequestOutput] = encoded_output[:split]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[split:]

        # 1 -> N: reuse the single left-hand embedding for every right item
        # instead of encoding the same text N times.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs jointly with a cross-encoder model.

        Each `(text_1[i], text_2[i])` pair is tokenized together (`text` /
        `text_pair`) into a single prompt and run through the pooling
        engine. When `text_1` holds exactly one item it is replicated to
        match `text_2` (the `1 -> N` case).

        Args:
            tokenizer: Tokenizer used to build the paired prompts.
            text_1: Left-hand texts.
            text_2: Right-hand texts.
            truncate_prompt_tokens: Optional prompt-length limit in tokens;
                validated against `max_model_len` and applied through the
                tokenization kwargs.
            use_tqdm: Whether to display a progress bar.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One `ScoringRequestOutput` per scored pair.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`, which is not
                supported for cross-encoder scoring.
        """
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `text_1` sentence will be replicated `N`
        times to pair with the `text_2` sentences.

        For cross-encoder models the pairs are scored jointly by the model.
        Otherwise every text is embedded and each pair is scored by the
        cosine similarity of its embeddings. This method automatically
        batches the prompts, considering the memory constraint. For the best
        performance, put all of your texts into a single list and pass it to
        this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the `text_2` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See [PromptType][vllm.inputs.PromptType] for
                more details about the format of each prompts.
            truncate_prompt_tokens: Optional prompt-length limit in tokens,
                forwarded to the scoring implementation.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, was not
                initialized with `--task embed` / `--task score`, or a
                prompt carries multi-modal data.
            TypeError: If a prompt cannot be converted to text.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs, so every
        # prompt is normalized to plain text below.
        tokenizer = self.llm_engine.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check instead of `assert`: assertions are stripped
            # under `python -O` and this validates caller input.
            if not isinstance(prompt, str):
                raise TypeError("Unsupported prompt type for scoring: "
                                f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)

        return self._embedding_score(
            tokenizer,
            input_text_1,  # type: ignore[arg-type]
            input_text_2,  # type: ignore[arg-type]
            truncate_prompt_tokens,
            use_tqdm,
            lora_request,
            prompt_adapter_request)

    def start_profile(self) -> None:
        """Start profiling by delegating to `self.llm_engine.start_profile()`."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to `self.llm_engine.stop_profile()`."""
        self.llm_engine.stop_profile()

    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the prefix cache by delegating to the engine.

        Args:
            device: Optional device selector, forwarded unchanged to
                `self.llm_engine.reset_prefix_cache`.

        Returns:
            The boolean reported by the engine (presumably whether the reset
            succeeded; confirm against the engine implementation).
        """
        return self.llm_engine.reset_prefix_cache(device)

    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # Both sleep levels discard the kv cache contents, so reset the
        # prefix cache first rather than keep entries that point at
        # forgotten cache data.
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][] method
        for more details.

        Args:
            tags: Optional list of tags selecting which memory allocations to
                reallocate. Allowed values are `("weights", "kv_cache")`;
                when None, all memory is reallocated. wake_up should be
                called with all tags (or None) before the engine is used
                again.
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1297

1298
1299
    # LEGACY
    def _convert_v1_inputs(
1300
        self,
1301
1302
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
1303
1304
    ):
        # skip_tokenizer_init is now checked in engine
1305

1306
1307
1308
1309
1310
1311
1312
1313
1314
        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")
        if prompts is not None and prompt_token_ids is not None \
                and len(prompts) != len(prompt_token_ids):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

1315
1316
1317
1318
1319
1320
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]
1321
1322
        if prompts is not None:
            num_requests = len(prompts)
1323
        elif prompt_token_ids is not None:
1324
            num_requests = len(prompt_token_ids)
1325
        parsed_prompts: list[PromptType] = []
1326
        for i in range(num_requests):
1327
            item: PromptType
1328

1329
            if prompts is not None:
1330
1331
1332
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
1333
            else:
1334
                raise AssertionError
1335

1336
            parsed_prompts.append(item)
1337

1338
        return parsed_prompts
1339
1340
1341

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: bool,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments, then submit each prompt.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt. Every ``SamplingParams``
                is mutated in place: guided options are applied and
                ``output_kind`` is set to ``FINAL_ONLY``.
            use_tqdm: Whether to show a progress bar while adding requests.
            lora_request: One LoRA request shared by every prompt, or a
                sequence with exactly one entry per prompt.
            prompt_adapter_request: Applied to every prompt.
            tokenization_kwargs: Forwarded to every request.
            guided_options: Deprecated. Copied into
                ``SamplingParams.guided_decoding`` for each sampling param.
            priority: Optional per-prompt priorities. Requests get priority
                0 when this is omitted or empty.

        Raises:
            ValueError: If ``params``, ``lora_request`` or ``priority`` is a
                sequence whose length differs from the number of prompts, or
                if ``guided_options`` is combined with an already-set
                ``params.guided_decoding``.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Validate every per-prompt sequence before anything reaches the
        # engine, so a bad argument cannot leave it half-populated. The
        # checks use ``Sequence`` (matching the annotations and the indexing
        # below) rather than ``list``, so tuples are validated too.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            it = tqdm(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1396

1397
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Hand one prompt to the engine under a freshly drawn request ID.

        The ID is the next value of ``self.request_counter``, converted to
        ``str``. All remaining arguments are forwarded to
        ``llm_engine.add_request`` as-is.
        """
        # Draw the ID first so the counter advances before the engine call.
        new_request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            new_request_id,
            prompt,
            params,
            priority=priority,
            prompt_adapter_request=prompt_adapter_request,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
        )
1416

1417
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy deprecated ``guided_options`` onto ``params.guided_decoding``.

        ``params`` is mutated in place and also returned. When
        ``guided_options`` is None, ``params`` is returned untouched.

        Raises:
            ValueError: If ``params.guided_decoding`` is already set while
                ``guided_options`` is also supplied.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")
            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )
        return params

1440
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to show a progress bar. For ``RequestOutput``
                results, the bar also shows estimated input/output token
                throughput.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        elapsed = pbar.format_dict["elapsed"]
                        # A coarse clock can report 0 elapsed seconds after a
                        # very fast first step. Skip the speed update instead
                        # of dividing by zero.
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                        pbar.update(n)
                    else:
                        pbar.update(1)
        finally:
            # Close the bar even if a step raises, so it is not left open.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were added before them.
        return sorted(outputs, key=lambda x: int(x.request_id))