llm.py 71.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Sequence
6
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar
luopl's avatar
luopl committed
13
from vllm.inputs import token_inputs
14
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
15
16
                              BeamSearchSequence,
                              create_sort_beams_key_function)
17
18
from vllm.config import (CompilationConfig, ModelDType,
                         StructuredOutputsConfig, TokenizerMode, is_init_field)
19
20
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
                                   PoolerConfig, RunnerOption)
nunjunj's avatar
nunjunj committed
21
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
22
                                         ChatTemplateContentFormatOption,
23
24
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
25
26
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
27
28
# yapf conflicts with isort for this block
# yapf: disable
29
30
31
32
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
33
                                          compress_token_type_ids,
34
                                          get_score_prompt)
35
# yapf: enable
36
37
from vllm.entrypoints.utils import (_validate_truncation_size,
                                    log_non_default_args)
38
39
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
                         TokensPrompt)
40
from vllm.logger import init_logger
41
from vllm.lora.request import LoRARequest
42
from vllm.model_executor.layers.quantization import QuantizationMethods
43
44
45
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
46
from vllm.plugins.io_processors import get_io_processor
47
from vllm.pooling_params import PoolingParams
48
49
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
                                  SamplingParams)
50
from vllm.tasks import PoolingTask
51
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
52
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
53
from vllm.usage.usage_lib import UsageContext
54
from vllm.utils import Counter, Device, as_iter, is_list_of
55
from vllm.v1.engine.llm_engine import LLMEngine
56
from vllm.v1.sample.logits_processor import LogitsProcessor
57

58
import vllm.envs as envs
if TYPE_CHECKING:
    # Imported for static type checking only; TYPE_CHECKING is False at
    # runtime, so this import is skipped when the module is loaded.
    from vllm.v1.metrics.reader import Metric

logger = init_logger(__name__)

# Return type of worker-side callables passed to `collective_rpc` and
# `apply_model`; defaults to `Any` (PEP 696 type-parameter default).
_R = TypeVar("_R", default=Any)


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        runner: The runner option (e.g. "generate"); defaults to "auto".
            Forwarded to [`EngineArgs`][vllm.EngineArgs].
        convert: The model conversion option; defaults to "auto".
            Forwarded to [`EngineArgs`][vllm.EngineArgs].
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
            Use "cpm" when loading the 9G tokenizer.
        skip_tokenizer_init: If true, skip initialization of the tokenizer and
            detokenizer. Inputs must then provide valid `prompt_token_ids` and
            `None` for `prompt`.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights, e.g.
            "awq", "gptq", or "fp8". If None, we first check the
            `quantization_config` attribute in the model config file. If that
            is None, we assume the model weights are not quantized and use
            `dtype` to determine the data type of the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause
            out-of-memory (OOM) errors.
        kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. Defaults
            to None, in which case vLLM infers the KV cache size from
            `gpu_memory_utilization`. Setting it gives finer-grained control
            over how much memory is used than `gpu_memory_utilization` does.
            When it is not None, `gpu_memory_utilization` is ignored.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, values that are too small may cause out-of-memory (OOM)
            errors. Note that `best_of` is only supported in V0.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        pooler_config: Initialize non-default pooling config for the pooling
            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
        override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
            argument is deprecated and will be removed in v0.12.0 or v1.0.0,
            whichever is sooner.
        structured_outputs_config: Either a dictionary or a
            `StructuredOutputsConfig`. Dictionary keys that are not init
            fields of `StructuredOutputsConfig` are ignored.
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig`. If it is an integer, it is used as the level
            of compilation optimization. If it is a dictionary, it can specify
            the full compilation configuration; keys that are not init fields
            of `CompilationConfig` are ignored.
        logits_processors: Logits processors to use, given as strings or as
            `LogitsProcessor` classes. Forwarded to
            [`EngineArgs`][vllm.EngineArgs].
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: Optional[str] = None,
        # Use tokenizer_mode="cpm" when loading the 9G tokenizer.
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: Optional[list[str]] = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        pooler_config: Optional[PoolerConfig] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        structured_outputs_config: Optional[Union[dict[
            str, Any], StructuredOutputsConfig]] = None,
        kv_cache_memory_bytes: Optional[int] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        logits_processors: Optional[list[Union[str,
                                               type[LogitsProcessor]]]] = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes a few arguments, builds `EngineArgs` from all of them and
        creates the engine. See the class docstring for the meaning of each
        argument. Normalization performed here:

        - `disable_log_stats` defaults to True unless present in `kwargs`.
        - A `worker_cls` given as a class is serialized with cloudpickle.
        - A `kv_transfer_config` given as a dict becomes a `KVTransferConfig`.
        - `hf_overrides=None` becomes an empty dict.
        - `compilation_config` (an int level or a dict) and
          `structured_outputs_config` (a dict) become config objects; dict
          keys that are not init fields of the config class are dropped.

        Raises:
            ValueError: If a dict `kv_transfer_config` fails validation as a
                `KVTransferConfig`.
        """
        # An explicit value from the caller wins over this default.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class given as a type (instead of a qualified-name string)
        # is serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        raw_kv_transfer_config = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer_config, dict):
            from vllm.config.kv_transfer import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_kv_transfer_config)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer_config, e)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _init_kwargs(config_cls: type,
                         values: dict[str, Any]) -> dict[str, Any]:
            # Keep only the keys that are constructor fields of `config_cls`.
            return {
                k: v
                for k, v in values.items() if is_init_field(config_cls, k)
            }

        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            compilation_config_instance = CompilationConfig(
                **_init_kwargs(CompilationConfig, compilation_config))
        else:
            compilation_config_instance = compilation_config

        if structured_outputs_config is None:
            structured_outputs_instance = StructuredOutputsConfig()
        elif isinstance(structured_outputs_config, dict):
            structured_outputs_instance = StructuredOutputsConfig(
                **_init_kwargs(StructuredOutputsConfig,
                               structured_outputs_config))
        else:
            structured_outputs_instance = structured_outputs_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            override_pooler_config=override_pooler_config,
            structured_outputs_config=structured_outputs_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        # Create the engine from the assembled arguments.
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

        supported_tasks = self.llm_engine.get_supported_tasks()  # type: ignore
        logger.info("Supported_tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        # Load the Input/Output processor plugin if any
        io_processor_plugin = self.llm_engine.model_config.io_processor_plugin
        self.io_processor = get_io_processor(self.llm_engine.vllm_config,
                                             io_processor_plugin)

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer used by the underlying engine."""
        return self.llm_engine.get_tokenizer()

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the underlying engine.

        The tokenizer is wrapped with `get_cached_tokenizer` unless its class
        name marks it as an already-cached tokenizer.

        Args:
            tokenizer: The tokenizer to install.
        """
        # Cached tokenizer classes are created dynamically, so the class name
        # is the only available check. A user-defined tokenizer whose class
        # name starts with "Cached" will be misjudged and installed unwrapped.
        if tokenizer.__class__.__name__.startswith("Cached"):
            self.llm_engine.tokenizer = tokenizer
        else:
            self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)

    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default sampling parameters for this model.

        The values from `model_config.get_diff_sampling_param()` are fetched
        on the first call and cached in `self.default_sampling_params`. If
        that mapping is empty, a plain `SamplingParams()` is returned.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = (
                self.llm_engine.model_config.get_diff_sampling_param())
        if self.default_sampling_params:
            return SamplingParams.from_optional(**self.default_sampling_params)
        return SamplingParams()

    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request(s) to use for generation, if any:
                either a single request applied to every prompt or a list
                with one entry per prompt. For multi-modal models with default
                modality-specific LoRAs configured, a matching LoRA may be
                chosen for prompts that have no explicit request.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not "generate".
        """
        model_config = self.llm_engine.model_config
        if model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model.")

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(
            prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def _get_modality_specific_lora_reqs(
            self, prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        """Attach default modality-specific LoRAs to multi-modal prompts.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: The caller's LoRA request(s): None, one request
                applied to every prompt, or one request per prompt.

        Returns:
            `lora_request` unchanged when there is no LoRA config, the model
            is not multi-modal, or no `default_mm_loras` are configured;
            otherwise a list with one optional `LoRARequest` per prompt.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or lora_config.default_mm_loras is None):
            return lora_request

        # Wrap a single prompt so it can be paired with LoRA requests. `str`
        # is itself a `Sequence`, so it needs an explicit check; otherwise a
        # text prompt would be split into one "prompt" per character.
        if isinstance(prompts, str) or not isinstance(prompts, Sequence):
            prompts = [prompts]

        optional_loras = ([lora_request] * len(prompts)
                          if not isinstance(lora_request, Sequence) else
                          lora_request)

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            ) for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

    def _resolve_single_prompt_mm_lora(self, prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        """Choose the LoRA request to use for a single prompt.

        Args:
            prompt: The prompt to resolve a LoRA for.
            lora_request: The LoRA request the caller provided, if any.
            default_mm_loras: Mapping from modality name to the path of the
                default LoRA registered for that modality.

        Returns:
            `lora_request` unchanged if there are no default LoRAs, the prompt
            has no `multi_modal_data`, none or more than one of its modalities
            has a default LoRA, or `lora_request` is already set. Otherwise a
            new `LoRARequest` for the single matching modality.
        """
        if (not default_mm_loras or not isinstance(prompt, dict)
                or "multi_modal_data" not in prompt):
            return lora_request

        prompt = cast(Union[TextPrompt, TokensPrompt], prompt)

        intersection = set(prompt["multi_modal_data"].keys()) \
            .intersection(default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", intersection)
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; warn only when its ID
        # differs from the modality LoRA that would otherwise be used.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.

        Args:
            func: Callable that receives the worker's model (`nn.Module`).

        Returns:
            A list with the result of `func` from each worker.

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
        """
        return self.llm_engine.apply_model(func)

    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: None, a single request applied to every prompt, or
                a sequence with exactly one (optional) request per prompt.
            prompts: The beam search prompts.

        Returns:
            A list with one optional `LoRARequest` per prompt.

        Raises:
            ValueError: If a sequence of requests does not have one entry per
                prompt.
            TypeError: If `lora_request` (or one of its entries) has an
                unsupported type.
        """
        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the "
                    "prompts")
            # A correctly sized sequence already pairs one request with each
            # prompt; accept it as long as every entry is a valid request.
            if all(req is None or isinstance(req, LoRARequest)
                   for req in lora_request):
                return list(lora_request)

        elif lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

    def beam_search(
        self,
545
        prompts: list[Union[TokensPrompt, TextPrompt]],
546
        params: BeamSearchParams,
547
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
548
        use_tqdm: bool = False,
549
        concurrency_limit: Optional[int] = None,
550
    ) -> list[BeamSearchOutput]:
551
552
553
554
555
556
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
557
            params: The beam search parameters.
558
            lora_request: LoRA request to use for generation, if any.
559
            use_tqdm: Whether to use tqdm to display the progress bar.
560
561
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
562
        """
563
564
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
565
566
567
568
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
569
570
        length_penalty = params.length_penalty

571
572
573
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

574
575
576
577
578
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
579

580
581
582
583
584
585
586
587
588
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar.")
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

589
590
591
592
593
594
595
596
597
598
599
600
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
601

602
603
604
605
606
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
607
                                            temperature=temperature)
608
        instances: list[BeamSearchInstance] = []
609

610
        for lora_req, prompt in zip(lora_requests, prompts):
611
612
613
614
615
616
617
618
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

619
620
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
621
622
623
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
624

625
            instances.append(
626
627
628
629
630
631
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
632

633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
        for prompt_start in range(0, len(prompts), concurrency_limit):
            instances_batch = instances[prompt_start:prompt_start +
                                        concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(token_iter,
                                  desc="Beam search",
                                  unit="token",
                                  unit_scale=False)
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress.")
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
                    sum((instance.beams for instance in instances_batch), []))
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch))
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:]))

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                      for beam in all_beams])

                # only runs for one step
                # we don't need to use tqdm here
                output = self.generate(prompts_batch,
                                       sampling_params=beam_search_params,
                                       use_tqdm=False,
                                       lora_request=lora_req_batch)

                for (start, end), instance in zip(instance_start_and_end,
                                                  instances_batch):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    multi_modal_data=current_beam.
                                    multi_modal_data,
                                    mm_processor_kwargs=current_beam.
                                    mm_processor_kwargs)

                                if token_id == tokenizer.eos_token_id and \
                                    not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(instance_new_beams,
                                          key=sort_beams_key,
                                          reverse=True)
                    instance.beams = sorted_beams[:beam_width]
706
707
708
709
710

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
711
                                      key=sort_beams_key,
712
713
714
715
716
717
718
719
720
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

721
    def preprocess_chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[TokensPrompt]:
        """
        Convert one or more chat conversations into tokenized prompts.

        The returned prompts can be passed to the other `LLM` methods.
        See `chat` for a complete description of the arguments.

        Returns:
            One `TokensPrompt` per conversation, holding the token ids
            produced by applying the chat template, plus any multi-modal
            data / UUIDs parsed from the messages and the supplied
            `mm_processor_kwargs`.
        """
        # A list of lists is a batch of conversations; anything else is a
        # single conversation, wrapped so both cases share one code path.
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list) else
            [cast(list[ChatCompletionMessageParam], messages)])

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Entries from `chat_template_kwargs` override the explicit
        # keyword arguments because they are merged last.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        tokenized: list[TokensPrompt] = []
        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path tokenizes directly from the raw messages.
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already emits the special tokens, so the
                # tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)

            tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                tokens_prompt["multi_modal_data"] = mm_data
            if mm_uuids is not None:
                tokens_prompt["multi_modal_uuids"] = mm_uuids
            if mm_processor_kwargs is not None:
                tokens_prompt["mm_processor_kwargs"] = mm_processor_kwargs

            tokenized.append(tokens_prompt)

        return tokenized

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is first tokenized via `preprocess_chat` (which
        applies the chat template), and the resulting prompts are handed to
        the [generate][vllm.LLM.generate] method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are forwarded to the chat
                template and used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        conversation_prompts = self.preprocess_chat(
            messages=messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        # Templating is done; the token prompts go through the regular
        # generation path.
        return self.generate(conversation_prompts,
                             sampling_params=sampling_params,
                             use_tqdm=use_tqdm,
                             lora_request=lora_request)

    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
                A single dict prompt may carry a `"qfeat"` entry; when it is
                not None it is copied onto every pooling params object.
                A dict prompt containing `"data"` is routed through the
                installed IOProcessor plugin.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: If not None, written onto every pooling
                params object (kept for backwards compatibility).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use. If explicitly
                passed as None, a task is chosen from `supported_tasks` and
                a warning is logged.
            tokenization_kwargs: overrides tokenization_kwargs set in
                pooling_params. NOTE(review): this argument is currently not
                referenced in this method's body; confirm whether it should
                be forwarded to request creation.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For IOProcessor (`"data"`) prompts, a single output wrapping the
            plugin's post-processed result is returned instead.

        Raises:
            ValueError: If the model is not a pooling model, `pooling_task`
                is not in `supported_tasks`, or a `"data"` prompt is given
                while no IOProcessor plugin is installed.
        """
        # Encode-only models need no task inference, so skip the warning.
        if self.supported_tasks == ["encode"] and pooling_task is None:
            pooling_task = "encode"

        if pooling_task is None:
            if "embed" in self.supported_tasks:
                pooling_task = "embed"
            else:
                pooling_task = "encode"

            logger.warning_once(
                "`LLM.encode` is currently using `pooling_task = %s`.\n"
                "Please use one of the more specific methods or set the "
                "task directly when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                "or `pooling_task=\"embed\"`.\n"
                "  - For classification logits, use `LLM.classify(...)` "
                "or `pooling_task=\"classify\"`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                "or `pooling_task=\"reward\"`\n"
                "  - For similarity scores, use `LLM.score(...)`.",
                pooling_task)

        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        if pooling_task not in self.supported_tasks:
            raise ValueError(
                f"pooling_task must be one of {self.supported_tasks}.")

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        # Forward an optional "qfeat" entry from a single dict prompt to the
        # pooling params. Only dict prompts can carry it. Indexing strings,
        # token lists or sequences of prompts with "qfeat" (or dicts that
        # lack the key) would raise, so use a guarded `.get`. `as_iter` also
        # covers the case where a sequence of pooling params was passed.
        qfeat = prompts.get("qfeat") if isinstance(prompts, dict) else None
        if qfeat is not None:
            for param in as_iter(pooling_params):
                param.qfeat = qfeat

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details.")

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput)

        if io_processor_prompt:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(
                model_output=model_outputs)

            return [
                PoolingRequestOutput[Any](request_id="",
                                          outputs=processed_outputs,
                                          prompt_token_ids=[],
                                          finished=True)
            ]
        else:
            return model_outputs

    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded unchanged to `encode`.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`.")

        pooled_outputs = self.encode(
            prompts,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        # Re-wrap the generic pooling outputs as embedding-specific outputs.
        return list(map(EmbeddingRequestOutput.from_base, pooled_outputs))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`.")

        pooled_outputs = self.encode(
            prompts,
            pooling_task="classify",
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        # Re-wrap the generic pooling outputs as classification outputs.
        return [
            ClassificationRequestOutput.from_base(pooled)
            for pooled in pooled_outputs
        ]

    def reward(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        This is a thin wrapper that calls `encode` with
        `pooling_task="encode"`.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded unchanged to `encode`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        forwarded_options = dict(
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="encode",
        )
        return self.encode(prompts, **forwarded_options)

    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by cosine similarity of their embeddings.

        Both sides are embedded in a single `encode(..., pooling_task="embed")`
        call. When `text_1` yields exactly one embedding, it is reused for
        every item of `text_2` (the 1 -> N case).

        Returns:
            A list of `ScoringRequestOutput` built from the cosine
            similarity results.
        """
        boundary = len(text_1)
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )

        # The first `boundary` outputs belong to text_1, the rest to text_2.
        left: list[PoolingRequestOutput] = embeddings[:boundary]
        right: list[PoolingRequestOutput] = embeddings[boundary:]

        # 1 -> N: broadcast the single left embedding across all right ones.
        if len(left) == 1:
            left = left * len(right)

        similarities = _cosine_similarity(tokenizer=tokenizer,
                                          embed_1=left,
                                          embed_2=right)

        validated = self.engine_class.validate_outputs(similarities,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score `(data_1[i], data_2[i])` pairs with a cross-encoder model.

        Each pair is turned into a single engine prompt via
        `get_score_prompt`. If `data_1` has exactly one item, it is repeated
        to match `data_2` (the 1 -> N case). Pairs are formed with `zip`.

        Args:
            tokenizer: Tokenizer used to build the pair prompts. Mistral
                tokenizers are rejected.
            data_1: Query side of the pairs.
            data_2: Document side of the pairs.
            truncate_prompt_tokens: Validated against the model's
                `max_model_len` and applied through the tokenization kwargs.
            use_tqdm: Progress-bar setting forwarded to the engine helpers.
            pooling_params: Pooling parameters. If None, a fresh
                `PoolingParams(task="score")` is used. The params are verified
                for the "score" task in place.
            lora_request: LoRA request(s) to use, if any.

        Returns:
            A list of `ScoringRequestOutput`, one per pair.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        # Fetched once; the previous version re-read it three times.
        model_config = self.llm_engine.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        pooling_params.verify("score", model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        prompts: list[PromptType] = []
        pooling_params_list: list[PoolingParams] = []

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            # Token type ids are moved out of the prompt and passed to the
            # model in compressed form through a per-request params clone,
            # so the shared `pooling_params` object is never modified here.
            if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See
                [PromptType][vllm.inputs.PromptType] for more details about
                the format of each prompt.
            truncate_prompt_tokens: Optional token limit forwarded unchanged
                to the cross-encoding or embedding scoring helper.
                NOTE(review): presumably truncates each tokenized prompt to
                this many tokens; confirm against `_validate_truncation_size`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.
        Raises:
            ValueError: If the model is not a pooling model, supports neither
                the "embed" nor the "classify" task, or is a cross-encoder
                whose `num_labels` is not 1; or, for non-multi-modal models,
                if the inputs contain `ScoreMultiModalParam`, multi-modal
                prompts, or prompts that cannot be reduced to plain text.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Try converting the model using "
                             "`--convert embed` or `--convert classify`.")

        if (model_config.is_cross_encoder
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(data: Union[SingletonPrompt,
                                            Sequence[SingletonPrompt],
                                            ScoreMultiModalParam]):
                if isinstance(data, dict) and "content" in data:
                    raise ValueError("ScoreMultiModalParam is not supported "
                                     f"for {model_config.architecture}")

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt):
                # Reduce a text-only prompt to a plain string; token IDs are
                # decoded back to text (see the tokenizer note above).
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check rather than `assert`: asserts are stripped
                # under `python -O`, which would let a non-text prompt through.
                if type(prompt) is not str:
                    raise ValueError(
                        "Unsupported prompt for scoring: expected text, "
                        "a TextPrompt or a TokensPrompt, got "
                        f"{type(prompt).__name__}")
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Unwrap ScoreMultiModalParam into its content list; wrap a bare
        # string into a one-element list.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
1416

1417
1418
1419
1420
1421
1422
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to end profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1423
1424
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, passed through unchanged to the
                engine's ``reset_prefix_cache``.

        Returns:
            The boolean result reported by the engine.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1425

1426
1427
1428
1429
1430
1431
    def sleep(self, level: int = 1):
        """
        Put the engine into sleep mode, during which it must not process any
        requests. The caller is responsible for making sure no request is in
        flight from this call until `wake_up` is called.

        The prefix cache is reset before the engine is put to sleep.

        Args:
            level: The sleep level.
                Level 1 offloads the model weights and discards the kv cache,
                whose content is forgotten. The weights are backed up in CPU
                memory, so make sure there is enough CPU memory to hold them.
                It suits sleeping and then waking up the engine to run the
                same model again.
                Level 2 discards both the model weights and the kv cache,
                forgetting the content of both. It suits sleeping and then
                waking up the engine to run a different or updated model,
                where the previous weights are not needed, and it reduces
                CPU memory pressure.
        """
        # Clear cached prefixes first, then hand off to the engine.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1448
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.

        Args:
            tags: Optional list of tags selecting which engine memory
                allocations to reallocate; allowed values are
                `("weights", "kv_cache")`. When None, all memory is
                reallocated. Before the engine is used again, wake_up must
                have been called with all tags (or with None).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1461

1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        The call is delegated directly to ``llm_engine.get_metrics()``.

        Returns:
            A list of ``Metric`` objects capturing the current state
            of all aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1474
1475
    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and add one engine request per prompt.

        All length checks run before any request is added, so a mismatched
        argument never leaves a partially-submitted batch behind.

        Args:
            prompts: A single prompt or a sequence of prompts. A bare `str`
                or `dict` is treated as a single prompt.
            params: One params object shared by every prompt, or a sequence
                with one entry per prompt. Any `SamplingParams` is switched to
                `RequestOutputKind.FINAL_ONLY` in place.
            use_tqdm: If truthy, wrap the prompts in a progress bar while
                adding them; a callable is used as the progress-bar factory,
                otherwise `tqdm` is used.
            lora_request: One LoRA request (or None) shared by every prompt,
                or a sequence with one entry per prompt.
            priority: Optional per-prompt priorities. When non-empty it must
                have one entry per prompt; otherwise every request gets
                priority 0.

        Raises:
            ValueError: If `params`, `lora_request` or `priority` is a
                per-prompt sequence whose length differs from the number of
                prompts.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]  # type: ignore[list-item]

        num_requests = len(prompts)
        params_is_seq = isinstance(params, Sequence)
        lora_is_seq = isinstance(lora_request, Sequence)

        if params_is_seq and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if lora_is_seq and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Checked up front like the others: a short list would otherwise fail
        # with IndexError only after earlier requests were already added.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if params_is_seq else (params, ):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        model_config = self.llm_engine.model_config

        for i, prompt in enumerate(it):
            if isinstance(prompt, dict):
                self._validate_mm_data_and_uuids(
                    prompt.get("multi_modal_data"),
                    prompt.get("multi_modal_uuids"))

            param = params[i] if params_is_seq else params

            tokenization_kwargs: dict[str, Any] = {}
            _validate_truncation_size(model_config.max_model_len,
                                      param.truncate_prompt_tokens,
                                      tokenization_kwargs)

            self._add_request(
                prompt,
                param,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if lora_is_seq else lora_request,
                priority=priority[i] if priority else 0,
            )
1532

1533
1534
1535
1536
1537
1538
1539
    def _validate_mm_data_and_uuids(
            self,
            multi_modal_data: Optional[Any],  # MultiModalDataDict
            multi_modal_uuids: Optional[Any],  # MultiModalUUIDDict
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1540
        then its corresponding UUID must be set.
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
                        if multi_modal_uuids is None or modality not in multi_modal_uuids or multi_modal_uuids[  # noqa: E501
                                modality] is None:
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
                                f"but UUID is not provided")
                        else:
                            if len(
                                    multi_modal_uuids[modality]
                            ) <= i or multi_modal_uuids[modality][i] is None:
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
                                    f"but UUID is not provided")
            else:
                if data is None and (multi_modal_uuids is None
                                     or modality not in multi_modal_uuids
                                     or multi_modal_uuids[modality] is None):
                    raise ValueError(f"Multi-modal data for {modality} is None"
                                     f" but UUID is not provided")

1568
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine.

        A fresh request ID is drawn from ``self.request_counter`` and passed,
        as a string, to ``llm_engine.add_request`` together with the prompt,
        params, LoRA request, tokenization kwargs and priority.
        """
        rid = next(self.request_counter)
        engine = self.llm_engine
        engine.add_request(
            str(rid),
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
1585

1586
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until it has no unfinished requests.

        Args:
            use_tqdm: If truthy, show a progress bar sized to the number of
                requests that were unfinished when this call started. A
                callable is used as the progress-bar factory; otherwise
                `tqdm` is used.

        Returns:
            All finished outputs collected while stepping, sorted by their
            integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs)
                            # A disabled tqdm (e.g. TQDM_DISABLE=1 or
                            # disable=True) reports an elapsed time of 0, so
                            # guard the division instead of crashing.
                            elapsed = pbar.format_dict["elapsed"]
                            if elapsed:
                                in_spd = total_in_toks / elapsed
                                out_spd = total_out_toks / elapsed
                            else:
                                in_spd = out_spd = 0.0
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                            pbar.update(n)
                        else:
                            pbar.update(1)
                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))