llm.py 71.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Sequence
6
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar
13

14
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
15
16
                              BeamSearchSequence,
                              create_sort_beams_key_function)
17
18
from vllm.config import (CompilationConfig, ModelDType,
                         StructuredOutputsConfig, TokenizerMode, is_init_field)
19
20
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
                                   PoolerConfig, RunnerOption)
nunjunj's avatar
nunjunj committed
21
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
22
                                         ChatTemplateContentFormatOption,
23
24
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
25
26
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
27
28
# yapf conflicts with isort for this block
# yapf: disable
29
30
31
32
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
33
                                          compress_token_type_ids,
34
                                          get_score_prompt)
35
# yapf: enable
36
37
from vllm.entrypoints.utils import (_validate_truncation_size,
                                    log_non_default_args)
38
39
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
                         TokensPrompt)
40
from vllm.logger import init_logger
41
from vllm.lora.request import LoRARequest
42
from vllm.model_executor.layers.quantization import QuantizationMethods
43
44
45
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
46
from vllm.plugins.io_processors import get_io_processor
47
from vllm.pooling_params import PoolingParams
48
49
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
                                  SamplingParams)
50
from vllm.tasks import PoolingTask
51
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
52
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
53
from vllm.usage.usage_lib import UsageContext
54
from vllm.utils import Counter, Device, as_iter, is_list_of
55
from vllm.v1.engine.llm_engine import LLMEngine
56
from vllm.v1.sample.logits_processor import LogitsProcessor
57

58
59
60
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

61
62
logger = init_logger(__name__)

63
64
_R = TypeVar("_R", default=Any)

65
66

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
67
68
69
70
71
72
73
74
75
76
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
77
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
78
79
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
80
81
82
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
83
84
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
85
86
87
88
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
89
90
91
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
92
93
94
95
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
96
        quantization: The method used to quantize the model weights. Currently,
97
            we support "awq", "gptq", and "fp8" (experimental).
98
99
100
101
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
102
103
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
104
105
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
106
107
108
109
110
111
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
112
113
114
115
116
117
118
119
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vllm can automatically infer the kv cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the kv cache memory size. kv_cache_memory_bytes
            allows more fine-grained control of how much memory gets used when
            compared with using gpu_memory_utilization. Note that when
            kv_cache_memory_bytes is not None, gpu_memory_utilization is
            ignored.
120
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
121
122
123
124
125
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
126
127
128
129
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
130
131
132
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
133
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
134
            When a sequence has context length larger than this, we fall back
135
136
137
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
138
139
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
140
141
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
142
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
145
146
147
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
148
149
150
151
152
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
153
154
155
156
157
        pooler_config: Initialize non-default pooling config for the pooling
            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
        override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
            argument is deprecated and will be removed in v0.12.0 or v1.0.0,
            whichever is sooner.
158
159
160
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
161
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
162

163
164
    Note:
        This class is intended to be used for offline inference. For online
165
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
166
    """
167
168
169
170

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        pooler_config: Optional[PoolerConfig] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        structured_outputs_config: Optional[Union[dict[
            str, Any], StructuredOutputsConfig]] = None,
        kv_cache_memory_bytes: Optional[int] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        logits_processors: Optional[list[Union[str,
                                               type[LogitsProcessor]]]] = None,
        **kwargs: Any,
    ) -> None:
        """Normalize the constructor arguments and start the engine.

        Shorthand forms are converted into the objects ``EngineArgs``
        expects: a dict for ``kv_transfer_config``, ``compilation_config``
        or ``structured_outputs_config``; an int level for
        ``compilation_config``; a class object for ``worker_cls``. The engine
        is then created, and its supported tasks are recorded. Finally, the
        configured input/output processor plugin (if any) is loaded.

        Raises:
            ValueError: If a ``kv_transfer_config`` dict fails validation.
        """

        def _init_fields_only(config_cls: type,
                              raw: dict[str, Any]) -> dict[str, Any]:
            # Drop every key that is not an init field of ``config_cls``;
            # unknown keys are discarded without a warning.
            return {
                key: value
                for key, value in raw.items()
                if is_init_field(config_cls, key)
            }

        # Stats logging stays off unless the caller asked for it explicitly.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class passed as a type object (instead of a qualified
        # name string) is serialized with cloudpickle to avoid pickling
        # issues.
        if "worker_cls" in kwargs and isinstance(kwargs["worker_cls"], type):
            kwargs["worker_cls"] = cloudpickle.dumps(kwargs["worker_cls"])

        raw_kv_transfer = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer, dict):
            from vllm.config.kv_transfer import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_kv_transfer)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer, e)
                # Surface the validation problem as a ValueError so the
                # caller gets context about which argument was rejected.
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        hf_overrides = {} if hf_overrides is None else hf_overrides

        # Resolve compilation_config: None -> defaults, int -> level,
        # dict -> filtered init fields, otherwise used as given.
        if compilation_config is None:
            resolved_compilation = CompilationConfig()
        elif isinstance(compilation_config, int):
            resolved_compilation = CompilationConfig(level=compilation_config)
        elif isinstance(compilation_config, dict):
            resolved_compilation = CompilationConfig(
                **_init_fields_only(CompilationConfig, compilation_config))
        else:
            resolved_compilation = compilation_config

        # Resolve structured_outputs_config the same way (no int shorthand).
        if structured_outputs_config is None:
            resolved_structured_outputs = StructuredOutputsConfig()
        elif isinstance(structured_outputs_config, dict):
            resolved_structured_outputs = StructuredOutputsConfig(
                **_init_fields_only(StructuredOutputsConfig,
                                    structured_outputs_config))
        else:
            resolved_structured_outputs = structured_outputs_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            override_pooler_config=override_pooler_config,
            structured_outputs_config=resolved_structured_outputs,
            compilation_config=resolved_compilation,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        # Build the engine from the assembled arguments.
        engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.llm_engine = engine
        self.engine_class = type(engine)

        self.request_counter = Counter()
        # Lazily filled by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

        tasks = engine.get_supported_tasks()  # type: ignore
        logger.info("Supported_tasks: %s", tasks)
        self.supported_tasks = tasks

        # Load the input/output processor plugin, if one is configured.
        plugin_name = engine.model_config.io_processor_plugin
        self.io_processor = get_io_processor(engine.vllm_config, plugin_name)

322
323
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
324
325

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the underlying engine.

        A tokenizer whose class name starts with ``"Cached"`` is stored
        as-is. Any other tokenizer is first wrapped with
        ``get_cached_tokenizer``.
        """
        # Cached tokenizer classes are generated dynamically, so the class
        # name prefix is the only available test. A user-defined tokenizer
        # whose class name happens to start with "Cached" is misjudged and
        # stored unwrapped.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        self.llm_engine.tokenizer = (tokenizer if looks_cached else
                                     get_cached_tokenizer(tokenizer))
333

334
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default sampling parameters for this model.

        The model config's overrides (``get_diff_sampling_param()``) are
        fetched on first use and cached in ``self.default_sampling_params``.
        Non-empty overrides are turned into a ``SamplingParams`` with
        ``SamplingParams.from_optional``. Otherwise a plain
        ``SamplingParams()`` is returned.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

342
343
344
345
346
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, the model's default sampling parameters are used (see
                `get_default_sampling_params`).
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. Either a
                single request or a list of requests. For multimodal models
                with default modality LoRAs configured, a request may be
                filled in per prompt from those defaults.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not `"generate"`.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model.")

        if sampling_params is None:
            # Fall back to the model's default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts.
        lora_request = self._get_modality_specific_lora_reqs(
            prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
410

411
    def _get_modality_specific_lora_reqs(
            self, prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        """Fill in default modality-specific LoRAs for multimodal prompts.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: ``None``, a single LoRA request shared by every
                prompt, or a list with one entry per prompt.

        Returns:
            ``lora_request`` unchanged in any of these cases: the engine has
            no LoRA config, the model is not multimodal, or
            ``default_mm_loras`` is None. Otherwise, a list with one
            (possibly ``None``) LoRA request per prompt, as resolved by
            ``_resolve_single_prompt_mm_lora``.

        Raises:
            ValueError: If ``lora_request`` is a list whose length differs
                from the number of prompts.
        """
        # Grab the lora config off the vllm config on the engine.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or lora_config.default_mm_loras is None):
            return lora_request

        # A plain string prompt is itself a Sequence; wrap it explicitly so
        # we produce one entry for the prompt, not one per character.
        if isinstance(prompts, str) or not isinstance(prompts, Sequence):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            # zip() would silently truncate a mismatched list.
            if len(lora_request) != len(prompts):
                raise ValueError("Lora request list should be the same "
                                 "length as the prompts")
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(prompts)

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            ) for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

440
    def _resolve_single_prompt_mm_lora(self, prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        """Choose the LoRA request to use for a single prompt.

        Args:
            prompt: The prompt to inspect.
            lora_request: The LoRA request explicitly supplied for this
                prompt, if any.
            default_mm_loras: Mapping from modality name to LoRA path.

        Returns:
            A new ``LoRARequest`` for the matching modality, but only when
            exactly one of the prompt's ``multi_modal_data`` modalities has a
            default LoRA and no request was supplied explicitly. Its name is
            the modality, its ID is the modality's 1-based position in the
            sorted modality names, and its path comes from
            ``default_mm_loras``.

            Otherwise ``lora_request`` is returned unchanged. That happens
            when there are no default LoRAs, when the prompt carries no (or
            empty) ``multi_modal_data``, when no modality or several
            modalities match, or when a request was supplied explicitly.
        """
        if (not default_mm_loras or not isinstance(prompt, dict)
                or not prompt.get("multi_modal_data")):
            # A None/empty multi_modal_data has no modalities to match; this
            # also avoids calling .keys() on None below.
            return lora_request

        prompt = cast(Union[TextPrompt, TokensPrompt], prompt)

        intersection = set(prompt["multi_modal_data"].keys()).intersection(
            default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", intersection)
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; warn only when its ID
        # differs from the one the modality default would have used.
        if lora_request is not None:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

486
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Invoke ``method`` on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and run on each worker. A callable must accept an
                additional ``self`` argument (the worker object) besides
                ``args`` and ``kwargs``.
            timeout: Upper bound in seconds on how long to wait; a
                [`TimeoutError`][] is raised when it elapses. `None` waits
                without a limit.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list with the result from each worker.

        Note:
            Use this API for control messages only. Move bulk data over a
            separately set up data-plane channel.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
515
516

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call ``func`` on the model inside each worker and collect one result
        per worker.

        !!! warning
            Results are sent back from the workers, so keep them small. Avoid
            returning large arrays or tensors; if you must return them, move
            them to CPU first so they do not take up additional VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
528

529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: ``None``, a single request shared by every prompt,
                or a sequence with exactly one entry per prompt.
            prompts: The beam-search prompts.

        Returns:
            A new list with one (possibly ``None``) LoRA request per prompt.

        Raises:
            ValueError: If ``lora_request`` is a sequence whose length differs
                from ``len(prompts)``.
            TypeError: If ``lora_request`` has any other type.
        """
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        # A correctly sized sequence already pairs one request with each
        # prompt. Strings are Sequences too but are never valid here.
        if isinstance(lora_request, Sequence) and not isinstance(
                lora_request, str):
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

545
546
    def beam_search(
        self,
547
        prompts: list[Union[TokensPrompt, TextPrompt]],
548
        params: BeamSearchParams,
549
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
550
        use_tqdm: bool = False,
551
        concurrency_limit: Optional[int] = None,
552
    ) -> list[BeamSearchOutput]:
553
554
555
556
557
558
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
559
            params: The beam search parameters.
560
            lora_request: LoRA request to use for generation, if any.
561
            use_tqdm: Whether to use tqdm to display the progress bar.
562
563
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
564
        """
565
566
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
567
568
569
570
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
571
572
        length_penalty = params.length_penalty

573
574
575
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

576
577
578
579
580
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
581

582
583
584
585
586
587
588
589
590
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar.")
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

591
592
593
594
595
596
597
598
599
600
601
602
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
603

604
605
606
607
608
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
609
                                            temperature=temperature)
610
        instances: list[BeamSearchInstance] = []
611

612
        for lora_req, prompt in zip(lora_requests, prompts):
613
614
615
616
617
618
619
620
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

621
622
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
623
624
625
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
626

627
            instances.append(
628
629
630
631
632
633
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
634

635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
        for prompt_start in range(0, len(prompts), concurrency_limit):
            instances_batch = instances[prompt_start:prompt_start +
                                        concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(token_iter,
                                  desc="Beam search",
                                  unit="token",
                                  unit_scale=False)
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress.")
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
                    sum((instance.beams for instance in instances_batch), []))
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch))
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:]))

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                      for beam in all_beams])

                # only runs for one step
                # we don't need to use tqdm here
                output = self.generate(prompts_batch,
                                       sampling_params=beam_search_params,
                                       use_tqdm=False,
                                       lora_request=lora_req_batch)

                for (start, end), instance in zip(instance_start_and_end,
                                                  instances_batch):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    multi_modal_data=current_beam.
                                    multi_modal_data,
                                    mm_processor_kwargs=current_beam.
                                    mm_processor_kwargs)

                                if token_id == tokenizer.eos_token_id and \
                                    not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(instance_new_beams,
                                          key=sort_beams_key,
                                          reverse=True)
                    instance.beams = sorted_beams[:beam_width]
708
709
710
711
712

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
713
                                      key=sort_beams_key,
714
715
716
717
718
719
720
721
722
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

723
    def preprocess_chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[TokensPrompt]:
        """
        Render one or more chat conversations into tokenized prompts.

        The returned prompts can be fed to the other LLM methods (for
        example `generate`). Refer to `chat` for a complete description of
        the arguments.

        Returns:
            A list of `TokensPrompt` objects, one per conversation, holding
            the token IDs produced by chat template interpolation plus any
            parsed multi-modal data, multi-modal UUIDs and the given
            `mm_processor_kwargs`.
        """
        # Normalize to a batch: a single conversation (a flat list of
        # messages) becomes a one-element list of conversations.
        conversations: list[list[ChatCompletionMessageParam]]
        if is_list_of(messages, list):
            conversations = cast(list[list[ChatCompletionMessageParam]],
                                 messages)
        else:
            conversations = [cast(list[ChatCompletionMessageParam], messages)]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Options shared by every conversation; entries in the user-supplied
        # chat_template_kwargs take precedence over the explicit arguments.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        results: list[TokensPrompt] = []
        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral path works on the raw messages and yields
                # token IDs directly.
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)

            tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                tokens_prompt["multi_modal_data"] = mm_data
            if mm_uuids is not None:
                tokens_prompt["multi_modal_uuids"] = mm_uuids
            if mm_processor_kwargs is not None:
                tokens_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            results.append(tokens_prompt)

        return results

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Each conversation is tokenized by applying the chat template (see
        `preprocess_chat`), and the resulting prompts are handed to the
        [generate][vllm.LLM.generate] method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, the default sampling parameters are used. A single
                value is applied to every prompt; a list must have the same
                length as the prompts and is paired with them one by one.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: Forwarded to the chat template; if True,
                the template adds a generation prompt.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions, forwarded to the chat template
                and used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. These override the explicit template arguments.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        tokenized_prompts = self.preprocess_chat(
            messages=messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        return self.generate(
            tokenized_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

907
908
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: Optional[PoolingTask] = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict
                containing a `"data"` key is treated as an IOProcessor plugin
                request instead.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Kept for backwards compatibility. If not
                None, it is assigned to `truncate_prompt_tokens` on every
                pooling params object (the caller's objects are modified in
                place).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use. If explicitly
                passed as None, "embed" is chosen when the model supports it
                and "encode" otherwise (a warning is logged unless "encode" is
                the only supported task).
            tokenization_kwargs: Intended to override tokenization_kwargs set
                in pooling_params. NOTE(review): this argument is currently not
                read anywhere in this method; confirm whether it should be
                forwarded to `_validate_and_add_requests`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For an IOProcessor request, a single-element list wrapping the
            plugin's post-processed output is returned instead.

        Raises:
            ValueError: If the model is not a pooling model, if `pooling_task`
                is not among the model's supported tasks, or if an IOProcessor
                request is given but no IOProcessor plugin is installed.
        """
        # With "encode" as the only supported task there is nothing to choose,
        # so resolve it silently (skipping the warning below).
        if self.supported_tasks == ["encode"] and pooling_task is None:
            pooling_task = "encode"

        if pooling_task is None:
            if "embed" in self.supported_tasks:
                pooling_task = "embed"
            else:
                pooling_task = "encode"

            logger.warning_once(
                "`LLM.encode` is currently using `pooling_task = %s`.\n"
                "Please use one of the more specific methods or set the "
                "task directly when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                "or `pooling_task=\"embed\"`.\n"
                "  - For classification logits, use `LLM.classify(...)` "
                "or `pooling_task=\"classify\"`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                "or `pooling_task=\"reward\"`\n"
                "  - For similarity scores, use `LLM.score(...)`.",
                pooling_task)

        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        if pooling_task not in self.supported_tasks:
            raise ValueError(
                f"pooling_task must be one of {self.supported_tasks}.")

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details.")

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput)

        if io_processor_prompt:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(
                model_output=model_outputs)

            return [
                PoolingRequestOutput[Any](request_id="",
                                          outputs=processed_outputs,
                                          prompt_token_ids=[],
                                          finished=True)
            ]
        else:
            return model_outputs
1037

1038
1039
1040
1041
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute an embedding vector for every prompt.

        The prompts are batched automatically with the memory constraint in
        mind; for the best throughput, pass all prompts in a single call.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If not None, forwarded to `encode`, which
                sets it on every pooling params object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, the
                default pooling parameters are used.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`.")

        pooled = self.encode(
            prompts,
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_task="embed",
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        truncate_prompt_tokens: Optional[int] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            truncate_prompt_tokens: If not None, forwarded to `encode`, which
                sets it on every pooling params object. Defaults to None,
                which leaves the pooling params untouched (same as before
                this argument existed); mirrors `embed` and `reward`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`.")

        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
    def reward(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[PoolingRequestOutput]:
        """
        Compute reward outputs for every prompt.

        This is a thin wrapper over `encode` that runs the "encode"
        pooling task.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If not None, forwarded to `encode`, which
                sets it on every pooling params object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, the
                default pooling parameters are used.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        return self.encode(
            prompts,
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
            pooling_task="encode",
        )

1172
1173
1174
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by the cosine similarity of their embeddings.

        Both sides are embedded together in one `encode` call (task
        "embed"). The first `len(text_1)` outputs belong to `text_1` and the
        rest to `text_2`. When the `text_1` side yields exactly one output,
        that output is repeated to pair with every `text_2` output.
        """
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="embed",
        )

        split_at = len(text_1)
        left: list[PoolingRequestOutput] = embeddings[:split_at]
        right: list[PoolingRequestOutput] = embeddings[split_at:]

        # 1 -> N case: broadcast the single query embedding.
        if len(left) == 1:
            left = left * len(right)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=left,
                                    embed_2=right)

        validated = self.engine_class.validate_outputs(scores,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score `(data_1[i], data_2[i])` pairs with a cross-encoder model.

        Each pair is rendered into one engine prompt via `get_score_prompt`
        and run through the engine with pooling params verified for the
        "score" task. A single-element `data_1` is replicated to pair with
        every entry of `data_2`.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        # Looked up once; the original re-read this attribute three times.
        model_config = self.llm_engine.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        pooling_params.verify("score", model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        prompts = list[PromptType]()
        pooling_params_list = list[PoolingParams]()

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            # token_type_ids are removed from the engine prompt; when
            # non-empty they are forwarded in compressed form through a
            # per-request clone of the pooling params (`extra_kwargs`).
            if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1277
1278
    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: If set, the maximum number of prompt
                tokens to keep. It is passed through unchanged to the
                cross-encoder or embedding scoring helper.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a pooling model, supports
                neither `embed` nor `classify`, is a cross-encoder whose
                `num_labels` is not 1, or if the inputs are unsupported
                (e.g. multi-modal data for a text-only model, or a prompt
                that cannot be resolved to a string).
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Try converting the model using "
                             "`--convert embed` or `--convert classify`.")

        if (model_config.is_cross_encoder
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(data: Union[SingletonPrompt,
                                            Sequence[SingletonPrompt],
                                            ScoreMultiModalParam]):
                if isinstance(data, dict) and "content" in data:
                    raise ValueError("ScoreMultiModalParam is not supported "
                                     f"for {model_config.architecture}")

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt):
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check rather than `assert`: asserts are stripped
                # under `python -O`, which would let unsupported input through.
                if not isinstance(prompt, str):
                    raise ValueError("Unsupported prompt type for scoring: "
                                     f"{type(prompt).__name__}")
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Unwrap `ScoreMultiModalParam` into its content list; wrap bare
        # strings so both sides are lists before length validation.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
1415

1416
1417
1418
1419
1420
1421
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1422
1423
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, forwarded unchanged to the
                engine's `reset_prefix_cache`.

        Returns:
            The boolean reported by the engine.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1424

1425
1426
1427
1428
1429
1430
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset first, then the engine is put to sleep at
        the requested level.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # Clear cached prefixes before the engine discards its kv cache.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1447
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                `("weights", "kv_cache")`. If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1460

1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics, exactly as returned by the underlying
            engine's `get_metrics()`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1473
1474
    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and add every prompt to the engine.

        Args:
            prompts: A single prompt (`str` or `dict`) or a sequence of
                prompts.
            params: One params object shared by all prompts, or a sequence
                with exactly one entry per prompt. `SamplingParams` objects
                are mutated in place to request only the final output.
            use_tqdm: If truthy, show a progress bar while adding requests;
                a callable is used as the progress-bar factory.
            lora_request: One LoRA request shared by all prompts, or a
                sequence with exactly one entry per prompt.
            priority: Optional per-prompt priorities. `None` or an empty list
                means priority 0 for every request.

        Raises:
            ValueError: If `params` or `lora_request` is a sequence whose
                length differs from the number of prompts, or if `priority`
                has fewer entries than there are prompts.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Check before anything is enqueued: a too-short list would otherwise
        # raise IndexError partway through the loop below, after earlier
        # requests had already been added to the engine.
        if priority and len(priority) < num_requests:
            raise ValueError("priority must contain at least one entry "
                             "per prompt.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        model_config = self.llm_engine.model_config

        for i, prompt in enumerate(it):
            if isinstance(prompt, dict):
                self._validate_mm_data_and_uuids(
                    prompt.get("multi_modal_data"),
                    prompt.get("multi_modal_uuids"))

            param = params[i] if isinstance(params, Sequence) else params

            tokenization_kwargs: dict[str, Any] = {}
            _validate_truncation_size(model_config.max_model_len,
                                      param.truncate_prompt_tokens,
                                      tokenization_kwargs)

            self._add_request(
                prompt,
                param,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                priority=priority[i] if priority else 0,
            )
1531

1532
1533
1534
1535
1536
1537
1538
    def _validate_mm_data_and_uuids(
            self,
            multi_modal_data: Optional[Any],  # MultiModalDataDict
            multi_modal_uuids: Optional[Any],  # MultiModalUUIDDict
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1539
        then its corresponding UUID must be set.
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
                        if multi_modal_uuids is None or modality not in multi_modal_uuids or multi_modal_uuids[  # noqa: E501
                                modality] is None:
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
                                f"but UUID is not provided")
                        else:
                            if len(
                                    multi_modal_uuids[modality]
                            ) <= i or multi_modal_uuids[modality][i] is None:
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
                                    f"but UUID is not provided")
            else:
                if data is None and (multi_modal_uuids is None
                                     or modality not in multi_modal_uuids
                                     or multi_modal_uuids[modality] is None):
                    raise ValueError(f"Multi-modal data for {modality} is None"
                                     f" but UUID is not provided")

1567
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine under a fresh request ID.

        The ID is the next value drawn from `self.request_counter`, converted
        to a string; the remaining arguments are forwarded unchanged to the
        engine's `add_request`.
        """
        counter_value = next(self.request_counter)
        engine = self.llm_engine
        engine.add_request(
            str(counter_value),
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
1584

1585
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If truthy, show a progress bar sized to the number of
                unfinished requests at entry; a callable is used as the
                progress-bar factory. For `RequestOutput`s the bar's postfix
                shows estimated input/output token throughput.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs)
                            # Elapsed time can be exactly 0.0 with a coarse
                            # clock when a request finishes right after the
                            # bar is created; skip the speed update then
                            # instead of raising ZeroDivisionError.
                            elapsed = pbar.format_dict["elapsed"]
                            if elapsed > 0:
                                in_spd = total_in_toks / elapsed
                                out_spd = total_out_toks / elapsed
                                pbar.postfix = (
                                    f"est. speed input: {in_spd:.2f} toks/s, "
                                    f"output: {out_spd:.2f} toks/s")
                            pbar.update(n)
                        else:
                            pbar.update(1)
                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))