llm.py 69.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Sequence
6
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar
13

14
import vllm.envs as envs
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
16
17
                              BeamSearchSequence,
                              create_sort_beams_key_function)
18
19
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
20
21
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
                                   PoolerConfig, RunnerOption)
Joe Runde's avatar
Joe Runde committed
22
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
23
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
24
                                         ChatTemplateContentFormatOption,
25
26
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
27
28
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
29
30
# yapf conflicts with isort for this block
# yapf: disable
31
32
33
34
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
35
                                          compress_token_type_ids,
36
                                          get_score_prompt)
37
# yapf: enable
38
39
from vllm.entrypoints.utils import (_validate_truncation_size,
                                    log_non_default_args)
40
41
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
                         TokensPrompt)
42
from vllm.logger import init_logger
43
from vllm.lora.request import LoRARequest
44
from vllm.model_executor.layers.quantization import QuantizationMethods
45
46
47
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
48
from vllm.plugins.io_processors import get_io_processor
49
from vllm.pooling_params import PoolingParams
50
51
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
                                  SamplingParams)
52
from vllm.tasks import PoolingTask
53
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
54
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
55
from vllm.usage.usage_lib import UsageContext
56
from vllm.utils import Counter, Device, as_iter, is_list_of
57
from vllm.v1.sample.logits_processor import LogitsProcessor
58

59
import vllm.envs as envs
lizhigong's avatar
lizhigong committed
60
from vllm.zero_overhead.llm_engine import ZeroOverheadEngine
61

62

63
64
65
# Module-level logger for the offline `LLM` entrypoint.
logger = init_logger(__name__)

# Result type of callables dispatched through `LLM.collective_rpc` and
# `LLM.apply_model`; falls back to `Any` when it cannot be inferred.
_R = TypeVar("_R", default=Any)

if TYPE_CHECKING:
    # Imported for static type checking only; never imported at runtime.
    from vllm.v1.metrics.reader import Metric

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
            9G tokenizers require "cpm".
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vLLM automatically infers the KV cache
            size from gpu_memory_utilization. Setting it explicitly gives
            finer-grained control over how much memory is used than
            gpu_memory_utilization does. When kv_cache_memory_bytes is not
            None, gpu_memory_utilization is ignored.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        override_pooler_config: Initialize non-default pooling config or
            override default pooling config for the pooling model.
            e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
        compilation_config: An integer, a dictionary, or a `CompilationConfig`
            instance. If it is an integer, it is used as the level of
            compilation optimization. If it is a dictionary, it can specify the
            full compilation configuration; keys that are not
            `CompilationConfig` init fields are ignored. An instance is used
            as-is.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: Optional[str] = None,
        # NOTE: 9G tokenizers need tokenizer_mode="cpm"; pass it explicitly
        # for those checkpoints. Every other model uses the "auto" default.
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        kv_cache_memory_bytes: Optional[int] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        logits_processors: Optional[list[Union[str,
                                               type[LogitsProcessor]]]] = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes a few convenience argument forms, builds an `EngineArgs`,
        creates the underlying engine and records the tasks the loaded model
        supports.

        Raises:
            ValueError: If `kv_transfer_config` is given as a dict that cannot
                be converted into a `KVTransferConfig`.
        """
        # Stats logging is off unless the caller explicitly asked for it.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class object (rather than its qualified-name string) is
        # serialized with cloudpickle up front to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # Accept kv_transfer_config as a plain dict and validate it here so a
        # bad config fails fast with a clear ValueError.
        raw_kv_transfer_config = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer_config, dict):
            from vllm.config.kv_transfer import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_kv_transfer_config)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer_config, e)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        # compilation_config may be an optimization level (int), a dict of
        # CompilationConfig fields, or a ready-made CompilationConfig.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Keys that are not CompilationConfig init fields are dropped.
            compilation_config_instance = CompilationConfig(
                **{
                    key: value
                    for key, value in compilation_config.items()
                    if is_init_field(CompilationConfig, key)
                })
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        # VLLM_ZERO_OVERHEAD selects the zero-overhead engine; otherwise the
        # regular LLMEngine is used (which autoselects V0 vs V1).
        engine_cls = (ZeroOverheadEngine
                      if envs.VLLM_ZERO_OVERHEAD else LLMEngine)
        self.llm_engine = engine_cls.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

        if envs.VLLM_USE_V1:
            supported_tasks = (
                self.llm_engine.get_supported_tasks())  # type: ignore
        else:
            supported_tasks = self.llm_engine.model_config.supported_tasks

        logger.info("Supported_tasks: %s", supported_tasks)

        self.supported_tasks = supported_tasks

        # Load the Input/Output processor plugin if any
        io_processor_plugin = self.llm_engine.model_config.io_processor_plugin
        self.io_processor = get_io_processor(self.llm_engine.vllm_config,
                                             io_processor_plugin)

    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer to use for requests carrying `lora_request`.

        Delegates to the engine's tokenizer group. With `lora_request=None`
        this is presumably the base model's tokenizer (decided by the
        tokenizer group, not here).
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` as the engine's tokenizer.

        A tokenizer that already looks cached (its class name starts with
        "Cached") is installed as-is; anything else is first wrapped with
        `get_cached_tokenizer`.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are generated dynamically, so the class
        # name is the only available marker. A user-defined tokenizer class
        # whose name happens to start with "Cached" is misdetected and will
        # be installed unwrapped.
        already_cached = tokenizer.__class__.__name__.startswith("Cached")
        tokenizer_group.tokenizer = (tokenizer if already_cached else
                                     get_cached_tokenizer(tokenizer))

    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default `SamplingParams` for this model.

        The model config's `get_diff_sampling_param()` result is fetched once
        and cached on `self.default_sampling_params`. When it is empty, a
        plain `SamplingParams()` is returned; otherwise its entries are
        passed to `SamplingParams.from_optional`.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = (
                self.llm_engine.model_config.get_diff_sampling_param())

        if not self.default_sampling_params:
            return SamplingParams()
        return SamplingParams.from_optional(**self.default_sampling_params)

    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not "generate".

        Note:
            To pass pre-tokenized input, use a `TokensPrompt`, e.g.
            `{"prompt_token_ids": [...]}`, as the prompt.
        """
        if self.llm_engine.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model.")

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        # Attach any default modality-specific LoRAs to multimodal prompts.
        lora_request = self._get_modality_specific_lora_reqs(
            prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def _get_modality_specific_lora_reqs(
            self, prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        """Attach default modality-specific LoRAs to multimodal prompts.

        Returns `lora_request` unchanged when there is no LoRA config, the
        model is not multimodal, or no `default_mm_loras` are registered.
        Otherwise returns a list with one optional `LoRARequest` per prompt,
        each resolved by `_resolve_single_prompt_mm_lora`.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or lora_config.default_mm_loras is None):
            return lora_request

        # Wrap a lone prompt. A plain string is itself a Sequence, so it must
        # be special-cased; otherwise it would be iterated character by
        # character and yield one LoRA entry per character.
        if isinstance(prompts, str) or not isinstance(prompts, Sequence):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(prompts)

        # NOTE(review): zip() truncates silently if a LoRA list is longer
        # than the prompts; confirm the downstream length check covers this.
        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            ) for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

    def _resolve_single_prompt_mm_lora(self, prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        """Choose the LoRA request for a single prompt.

        If the prompt's multimodal data uses exactly one modality that has a
        registered default LoRA, a `LoRARequest` for that LoRA is returned.
        An explicitly provided `lora_request` always takes precedence. In all
        other cases (no defaults, non-dict prompt, no or empty multimodal
        data, zero or several matching modalities) `lora_request` is
        returned unchanged.
        """
        if not default_mm_loras or not isinstance(prompt, dict):
            return lora_request

        prompt = cast(Union[TextPrompt, TokensPrompt], prompt)

        # `multi_modal_data` may be missing, None, or empty; all mean there
        # is no modality to match (None used to raise AttributeError).
        mm_data = prompt.get("multi_modal_data")
        if not mm_data:
            return lora_request

        intersection = set(mm_data.keys()).intersection(
            default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", intersection)
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; warn only when its ID
        # differs from the modality default's ID.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        # Pure delegation: the engine owns the workers.
        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Run `func` directly on the model inside each worker.

        Args:
            func: Callable that receives the worker's model module.

        Returns:
            The value returned by `func` in each worker, as a list.
        """
        return self.llm_engine.model_executor.apply_model(func)

    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: None, a single request applied to every prompt, or
                a sequence with exactly one entry per prompt.
            prompts: The beam-search prompts.

        Returns:
            A list with one optional `LoRARequest` per prompt.

        Raises:
            ValueError: If a sequence of requests does not match the number
                of prompts.
            TypeError: If `lora_request` has any other type.
        """
        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(prompts):
                raise ValueError("Lora request list should be the same "
                                 "length as the prompts")
            # A correctly sized sequence is already per-prompt; it used to
            # fall through to the TypeError below.
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

    def beam_search(
        self,
540
        prompts: list[Union[TokensPrompt, TextPrompt]],
541
        params: BeamSearchParams,
542
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
543
        use_tqdm: bool = False,
544
        concurrency_limit: Optional[int] = None,
545
    ) -> list[BeamSearchOutput]:
546
547
548
549
550
551
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
552
            params: The beam search parameters.
553
            lora_request: LoRA request to use for generation, if any.
554
            use_tqdm: Whether to use tqdm to display the progress bar.
555
556
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
557
        """
558
559
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
560
561
562
563
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
564
565
        length_penalty = params.length_penalty

566
567
568
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

569
570
571
572
573
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
574

575
576
577
578
579
580
581
582
583
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar.")
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

584
585
586
587
588
589
590
591
592
593
594
595
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
596

597
598
599
600
601
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
602
                                            temperature=temperature)
603
        instances: list[BeamSearchInstance] = []
604

605
        for lora_req, prompt in zip(lora_requests, prompts):
606
607
608
609
610
611
612
613
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

614
615
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
616
617
618
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
619

620
            instances.append(
621
622
623
624
625
626
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
627

628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
        for prompt_start in range(0, len(prompts), concurrency_limit):
            instances_batch = instances[prompt_start:prompt_start +
                                        concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(token_iter,
                                  desc="Beam search",
                                  unit="token",
                                  unit_scale=False)
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress.")
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
                    sum((instance.beams for instance in instances_batch), []))
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch))
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:]))

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                      for beam in all_beams])

                # only runs for one step
                # we don't need to use tqdm here
                output = self.generate(prompts_batch,
                                       sampling_params=beam_search_params,
                                       use_tqdm=False,
                                       lora_request=lora_req_batch)

                for (start, end), instance in zip(instance_start_and_end,
                                                  instances_batch):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    multi_modal_data=current_beam.
                                    multi_modal_data,
                                    mm_processor_kwargs=current_beam.
                                    mm_processor_kwargs)

                                if token_id == tokenizer.eos_token_id and \
                                    not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(instance_new_beams,
                                          key=sort_beams_key,
                                          reverse=True)
                    instance.beams = sorted_beams[:beam_width]
701
702
703
704
705

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
706
                                      key=sort_beams_key,
707
708
709
710
711
712
713
714
715
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

716
    def preprocess_chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[TokensPrompt]:
        """
        Generate prompt for a chat conversation. The pre-processed
        prompt can then be used as input for the other LLM methods.

        Refer to `chat` for a complete description of the arguments.

        Returns:
            A list of `TokensPrompt` objects, one per conversation, containing
            the tokenized prompt after chat template interpolation and, when
            present, the parsed multi-modal data, the multi-modal UUIDs and
            the given `mm_processor_kwargs`.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Explicit arguments first; any key in the caller-supplied
        # `chat_template_kwargs` overrides them.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[TokensPrompt] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path renders straight to token ids from the
                # raw messages.
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            # Only attach the optional fields that are actually present.
            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_uuids is not None:
                prompt["multi_modal_uuids"] = mm_uuids

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return prompts

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the [generate][vllm.LLM.generate] method to generate
        the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`
                - "auto" (the default) defers the choice to
                  `resolve_chat_template_content_format`.

            add_generation_prompt: If True, asks the chat template to append
                a generation prompt after the conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional list of tool definitions. They are passed to the
                chat template and also used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. Keys given here override the explicit arguments
                above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        prompts = self.preprocess_chat(
            messages=messages,
            lora_request=lora_request,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

902
903
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict
                with a `"data"` key is treated as an IOProcessor plugin
                request.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Kept for backwards compatibility. If not
                None, it is assigned to `truncate_prompt_tokens` on each of
                the pooling params (this mutates the objects passed in).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use. Must be one of
                the tasks supported by the model.
            tokenization_kwargs: overrides tokenization_kwargs set in
                pooling_params

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For an IOProcessor request, a single-element list wrapping the
            plugin's post-processed output is returned instead.

        Raises:
            ValueError: If the model is not a pooling model, if
                `pooling_task` is not supported, or if an IOProcessor request
                is given but no IOProcessor plugin is installed.
        """
        # Only reachable when a caller explicitly passes `pooling_task=None`;
        # the declared default is "encode".
        if pooling_task is None:
            if "embed" in self.supported_tasks:
                pooling_task = "embed"
            else:
                pooling_task = "encode"

            logger.warning_once(
                "`LLM.encode` is currently using `pooling_task = %s`.\n"
                "Please use one of the more specific methods or set the "
                "task directly when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                "or `pooling_task=\"embed\"`.\n"
                "  - For classification logits, use `LLM.classify(...)` "
                "or `pooling_task=\"classify\"`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                "or `pooling_task=\"reward\"`\n"
                "  - For similarity scores, use `LLM.score(...)`.",
                pooling_task)

        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        if pooling_task not in self.supported_tasks:
            raise ValueError(
                f"pooling_task must be one of {self.supported_tasks}.")

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details.")

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        # NOTE(review): `tokenization_kwargs` is accepted but never read in
        # this method; confirm whether it should be forwarded here.
        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput)

        if io_processor_prompt:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(
                model_output=model_outputs)

            return [
                PoolingRequestOutput[Any](request_id="",
                                          outputs=processed_outputs,
                                          prompt_token_ids=[],
                                          finished=True)
            ]
        else:
            return model_outputs
1028

1029
1030
1031
1032
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If not None, forwarded to `encode`, which
                applies it to every pooling params object.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`.")

        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
        )

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If not None, forwarded to `encode`, which
                applies it to every pooling params object. Defaults to None
                (no change from previous behavior).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`.")

        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
    def reward(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If not None, forwarded to `encode`, which
                applies it to every pooling params object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """

        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_task="encode",
        )

1163
1164
1165
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their embeddings.

        `text_1` and `text_2` are embedded together in a single `encode`
        call with `pooling_task="embed"`. If `text_1` has exactly one entry,
        its embedding is reused for every entry of `text_2` (the `1 -> N`
        case). Otherwise the i-th entries of both lists are paired.

        Returns:
            One `ScoringRequestOutput` per pair.
        """
        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="embed",
        )

        # Outputs come back in input order: first all of text_1, then text_2.
        split_at = len(text_1)
        encoded_output_1: list[PoolingRequestOutput] = encoded_output[
            0:split_at]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[
            split_at:]

        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs with a cross-encoder model.

        Each `(data_1[i], data_2[i])` pair is turned into one prompt by
        `get_score_prompt`. If `data_1` has exactly one entry, it is paired
        with every entry of `data_2`. When a prompt carries
        `token_type_ids`, they are compressed and passed to the model
        through a per-request clone of `pooling_params`.

        Returns:
            One `ScoringRequestOutput` per pair.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        model_config = self.llm_engine.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")

        pooling_params.verify("score", model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        prompts: list[PromptType] = []
        pooling_params_list: list[PoolingParams] = []

        for q, d in list(zip(data_1, data_2)):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            # token_type_ids are removed from the prompt and sent via the
            # pooling params instead, so each such request gets its own copy.
            if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1268
1269
    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. The `LLM` class automatically batches the
        prompts, considering the memory constraint. For the best performance,
        put all of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must either hold a single
                element (the `1 -> N` case) or have the same length as the
                `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Optional prompt-truncation length, passed
                through unchanged to the cross-encoding / embedding scoring
                helper.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request(s) to apply, if any; passed through to
                the scoring helper.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, supports neither
                the ``embed`` nor the ``classify`` task, is a cross-encoder
                whose ``num_labels`` is not 1, or if an input cannot be used
                with this model (e.g. ``ScoreMultiModalParam`` or multi-modal
                prompts for a text-only model, or a prompt that cannot be
                turned into a string).
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Try converting the model using "
                             "`--convert embed` or `--convert classify`.")

        if (model_config.is_cross_encoder
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(data: Union[SingletonPrompt,
                                            Sequence[SingletonPrompt],
                                            ScoreMultiModalParam]):
                # A dict carrying "content" is a ScoreMultiModalParam, which
                # a text-only model cannot consume.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError("ScoreMultiModalParam is not supported "
                                     f"for {model_config.architecture}")

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt):
                # Normalize one text prompt (str, TextPrompt or TokensPrompt)
                # into a plain string; token IDs are decoded back to text.
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Validate explicitly instead of with ``assert`` so the check
                # still runs when Python is started with ``-O``.
                if not isinstance(prompt, str):
                    raise ValueError("Unsupported prompt type for scoring: "
                                     f"{type(prompt).__name__}")
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Unwrap ScoreMultiModalParam inputs to their "content" list, and
        # wrap any remaining bare string into a one-element list.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
1406

1407
1408
1409
1410
1411
1412
    def start_profile(self) -> None:
        """Forward a start-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Forward a stop-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

1413
1414
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Forward a prefix-cache reset to the underlying engine.

        Args:
            device: Passed unchanged to the engine's ``reset_prefix_cache``.

        Returns:
            The boolean reported by the engine's ``reset_prefix_cache``.
        """
        engine = self.llm_engine
        outcome = engine.reset_prefix_cache(device)
        return outcome
1415

1416
1417
1418
1419
1420
1421
    def sleep(self, level: int = 1):
        """Put the engine into sleep mode.

        While asleep, the engine must not process any requests; the caller is
        responsible for ensuring nothing is in flight until `wake_up` is
        called. The prefix cache is reset before the engine goes to sleep.

        Args:
            level: The sleep level.

                * Level 1 offloads the model weights and discards the KV
                  cache, whose content is forgotten. The weights are backed
                  up in CPU memory, so make sure there is enough CPU memory
                  to hold them. Use it to sleep and later wake the engine
                  with the same model.
                * Level 2 discards both the model weights and the KV cache;
                  the content of both is forgotten. Use it when waking the
                  engine to run a different or updated model for which the
                  previous weights are not needed. It reduces CPU memory
                  pressure.
        """
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1438
    def wake_up(self, tags: Optional[list[str]] = None):
        """Bring the engine back from sleep mode.

        This is the counterpart of [sleep][vllm.LLM.sleep]; see that method
        for the meaning of each sleep level.

        Args:
            tags: Optional subset of memory allocations to re-allocate. Each
                value must be one of `("weights", "kv_cache")`. `None`
                re-allocates all memory. Before the engine is used again,
                wake_up must have been called with every tag (or with `None`).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1451

1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects capturing the current state of all
            aggregated metrics from Prometheus.

        Raises:
            RuntimeError: If the engine is not the V1 ``LLMEngine``; this
                method is only available with the V1 LLM engine.
        """
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

        # Use an explicit check rather than ``assert`` so the guard (and its
        # message) still apply when Python runs with ``-O``.
        if not isinstance(self.llm_engine, V1LLMEngine):
            raise RuntimeError(
                "LLM.get_metrics() is only available with the V1 LLM engine.")
        return self.llm_engine.get_metrics()

1466
1467
    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments, then submit every prompt.

        Args:
            prompts: A single prompt or a sequence of prompts. A bare ``str``
                or ``dict`` is treated as one prompt.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt. Any ``SamplingParams`` are
                mutated in place to request only the final output.
            use_tqdm: If truthy, wrap submission in a progress bar. A callable
                is used as the tqdm factory; ``True`` uses ``tqdm``.
            lora_request: One LoRA request shared by every prompt, a sequence
                with exactly one entry per prompt, or ``None``.
            priority: Optional per-prompt priorities. When ``None`` or empty,
                every request is submitted with priority ``0``.

        Raises:
            ValueError: If ``params``, ``lora_request`` or a non-empty
                ``priority`` is a per-prompt sequence whose length differs
                from the number of prompts.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Validate before anything is submitted: a short priority list would
        # otherwise raise IndexError midway through the loop below, after
        # some requests had already been handed to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        model_config = self.llm_engine.model_config

        for i, prompt in enumerate(it):
            param = params[i] if isinstance(params, Sequence) else params

            tokenization_kwargs: dict[str, Any] = {}
            _validate_truncation_size(model_config.max_model_len,
                                      param.truncate_prompt_tokens,
                                      tokenization_kwargs)

            self._add_request(
                prompt,
                param,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                priority=priority[i] if priority else 0,
            )
1519

1520
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
    ) -> None:
        """Hand one prompt to the engine under a freshly allocated request ID.

        The ID is the next value drawn from ``self.request_counter``,
        converted to ``str``. Every other argument is forwarded unchanged to
        ``self.llm_engine.add_request``.
        """
        counter_value = next(self.request_counter)
        self.llm_engine.add_request(
            str(counter_value),
            prompt,
            params,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            priority=priority,
        )
1537

1538
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If truthy, show a progress bar sized to the number of
                unfinished requests when this method starts. A callable is
                used as the tqdm factory; ``True`` uses ``tqdm``.

        Returns:
            Every finished output, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs)
                            # tqdm's elapsed time can still be 0 (e.g. the
                            # first output finishes before the timer
                            # advances); avoid a ZeroDivisionError then.
                            elapsed = pbar.format_dict["elapsed"]
                            if elapsed > 0:
                                in_spd = total_in_toks / elapsed
                                out_spd = total_out_toks / elapsed
                            else:
                                in_spd = out_spd = 0.0
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                            pbar.update(n)
                        else:
                            pbar.update(1)
                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))