llm.py 72.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
6
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
    CompilationConfig,
22
    PoolerConfig,
23
24
25
    StructuredOutputsConfig,
    is_init_field,
)
26
from vllm.config.compilation import CompilationMode
27
from vllm.config.model import (
28
29
    ConvertOption,
    HfOverrides,
30
    ModelDType,
31
    RunnerOption,
32
    TokenizerMode,
33
)
34
from vllm.engine.arg_utils import EngineArgs
35
from vllm.engine.protocol import Device
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
    resolve_chat_template_content_format,
)
from vllm.entrypoints.score_utils import (
    ScoreContentPartParam,
    ScoreMultiModalParam,
    _cosine_similarity,
    _validate_score_input_lens,
    compress_token_type_ids,
    get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
from vllm.inputs import (
    DataPrompt,
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
60
from vllm.inputs.parse import get_prompt_components
61
from vllm.logger import init_logger
62
from vllm.lora.request import LoRARequest
63
from vllm.model_executor.layers.quantization import QuantizationMethods
64
65
66
67
68
69
70
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
71
from vllm.platforms import current_platform
72
from vllm.pooling_params import PoolingParams
73
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
74
from vllm.tasks import PoolingTask
75
76
77
78
79
from vllm.transformers_utils.tokenizer import (
    AnyTokenizer,
    MistralTokenizer,
    get_cached_tokenizer,
)
yhu422's avatar
yhu422 committed
80
from vllm.usage.usage_lib import UsageContext
81
from vllm.utils.collection_utils import as_iter, is_list_of
82
from vllm.utils.counter import Counter
83
from vllm.v1.engine import EngineCoreRequest
84
from vllm.v1.engine.llm_engine import LLMEngine
85
from vllm.v1.sample.logits_processor import LogitsProcessor
86

87
88
89
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

90
91
logger = init_logger(__name__)

92
93
_R = TypeVar("_R", default=Any)

94
95

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
96
97
98
99
100
101
102
103
104
105
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
106
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
107
108
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
109
110
111
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
112
113
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
114
115
116
117
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
118
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
120
121
122
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
126
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
132
133
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
134
135
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
136
137
138
139
140
141
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
142
143
144
145
146
        kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. By
            default, this is None and vLLM infers the KV cache size from
            `gpu_memory_utilization`. Setting it explicitly gives finer-grained
            control over how much memory is used than `gpu_memory_utilization`
            does. Note that when `kv_cache_memory_bytes` is not None,
            `gpu_memory_utilization` is ignored.
150
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, values that are too small may cause out-of-memory (OOM)
            errors. Note that `best_of` is only supported in V0.
156
157
158
159
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
160
161
162
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
163
164
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
165
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`) is used.
168
169
170
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
171
172
173
174
175
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
176
177
178
179
180
        pooler_config: Initialize non-default pooling config for the pooling
            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
        override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
            argument is deprecated and will be removed in v0.12.0 or v1.0.0,
            whichever is sooner.
181
        compilation_config: An integer, a dictionary, or a `CompilationConfig`.
            If it is an integer, it is used as the mode of compilation
            optimization. If it is a dictionary, it can specify the full
            compilation configuration; keys that are not init fields of
            `CompilationConfig` are ignored.
184
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
185

186
187
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
189
    """
190
191
192
193

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int | None = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        override_pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes the convenience forms accepted for several arguments
        (dict/int configs, a class passed as `worker_cls`) into the objects
        `EngineArgs` expects, rejects unsupported single-process
        data-parallel setups, and creates the underlying engine. See the
        class docstring for the meaning of each argument.

        Raises:
            ValueError: If `kv_transfer_config` is a dict that fails
                `KVTransferConfig` validation, or if `data_parallel_size > 1`
                is requested without the `external_launcher` distributed
                executor backend on a non-TPU platform.
        """
        # Stats logging is off by default for offline use; an explicit
        # user-provided value is respected.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # A plain dict is accepted for kv_transfer_config; validate it here so
        # the user gets a clear error at construction time.
        raw_config_dict = kwargs.get("kv_transfer_config")
        if isinstance(raw_config_dict, dict):
            from vllm.config.kv_transfer import KVTransferConfig

            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                # Consider re-raising a more specific vLLM error or ValueError
                # to provide better context to the user.
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        # compilation_config may be an int (the compilation mode), a dict of
        # CompilationConfig fields (non-init fields are dropped), or an
        # already-built CompilationConfig.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        elif isinstance(compilation_config, dict):
            compilation_config_instance = CompilationConfig(
                **{
                    k: v
                    for k, v in compilation_config.items()
                    if is_init_field(CompilationConfig, k)
                }
            )
        else:
            compilation_config_instance = compilation_config

        # structured_outputs_config may be a dict of StructuredOutputsConfig
        # fields (non-init fields are dropped) or an already-built instance.
        if structured_outputs_config is None:
            structured_outputs_instance = StructuredOutputsConfig()
        elif isinstance(structured_outputs_config, dict):
            structured_outputs_instance = StructuredOutputsConfig(
                **{
                    k: v
                    for k, v in structured_outputs_config.items()
                    if is_init_field(StructuredOutputsConfig, k)
                }
            )
        else:
            structured_outputs_instance = structured_outputs_config

        # Data parallelism inside a single LLM process is rejected (it may
        # hang) unless the external launcher backend is used or on TPU.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            override_pooler_config=override_pooler_config,
            structured_outputs_config=structured_outputs_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        # Create the (V1) engine from the fully-normalized arguments.
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        self.model_config = self.llm_engine.model_config
        self.processor = self.llm_engine.processor
        self.io_processor = self.llm_engine.io_processor
359

360
361
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer currently used by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
362

363
    @deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the engine (deprecated).

        Tokenizers that do not look cached are wrapped with
        `get_cached_tokenizer` before being installed.
        """
        # Cached tokenizer classes are created dynamically, so the class-name
        # prefix is the only available check. A user-defined tokenizer whose
        # class name starts with "Cached" will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        self.llm_engine.tokenizer = (
            tokenizer if looks_cached else get_cached_tokenizer(tokenizer)
        )
372

373
374
375
376
    def reset_mm_cache(self) -> None:
        """Clear the multi-modal caches of the input processor and the engine."""
        processor = self.processor
        processor.clear_mm_cache()
        # The engine-side cache is cleared after the processor-side one.
        engine = self.llm_engine
        engine.reset_mm_cache()

377
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default `SamplingParams` for this model.

        The model config's sampling-parameter overrides are fetched once and
        cached on `self.default_sampling_params`. With no overrides, a plain
        `SamplingParams()` is returned.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = self.model_config.get_diff_sampling_param()

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

384
385
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running as a generative model.

        Note:
            To pass pre-tokenized input, use a `TokensPrompt` (a dict with a
            `prompt_token_ids` key) as the prompt.
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
451

452
    def _get_modality_specific_lora_reqs(
        self,
        prompts: PromptType | Sequence[PromptType],
        lora_request: list[LoRARequest] | LoRARequest | None,
    ):
        """Resolve default modality-specific LoRAs for each prompt.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: An explicit LoRA request applied to every prompt,
                a per-prompt sequence of them, or None.

        Returns:
            `lora_request` unchanged when there is no LoRA config, the model
            is not multi-modal, or no default multi-modal LoRAs are registered.
            Otherwise a list with one (possibly None) LoRA request per prompt.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (
            lora_config is None
            or not self.model_config.is_multimodal_model
            or lora_config.default_mm_loras is None
        ):
            return lora_request

        # A single text prompt is a `str`, which is itself a Sequence; wrap it
        # (like a single dict prompt) so we iterate over prompts, not chars.
        if isinstance(prompts, str) or not isinstance(prompts, Sequence):
            prompts = [prompts]

        optional_loras = (
            [lora_request] * len(prompts)
            if not isinstance(lora_request, Sequence)
            else lora_request
        )

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            )
            for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

488
489
490
    def _resolve_single_prompt_mm_lora(
        self,
        prompt: PromptType,
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
    ):
        """Pick the LoRA request to use for a single prompt.

        Args:
            prompt: The prompt; only dict prompts carrying non-empty
                `multi_modal_data` can trigger a default modality LoRA.
            lora_request: The explicitly requested LoRA for this prompt, if any.
            default_mm_loras: Mapping of modality name to LoRA path.

        Returns:
            The explicit `lora_request` when one is given, when no default
            modality LoRA matches, or when several modalities match. Otherwise
            a new `LoRARequest` for the single matching modality.
        """
        if (
            not default_mm_loras
            or not isinstance(prompt, dict)
            or not (mm_data := prompt.get("multi_modal_data") or {})
        ):
            return lora_request

        intersection = set(
            mm_data.keys()  # type: ignore
        ).intersection(default_mm_loras.keys())

        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped",
                intersection,
            )
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; warn only when its ID
        # differs from the ID the default modality LoRA would have used.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt"
                )
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

541
542
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Run an RPC on every worker and collect the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and executed on each worker.

                A callable receives the worker object as an extra leading
                `self` argument, followed by `args` and `kwargs`.
            timeout: Maximum number of seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, in a list.

        Note:
            Prefer using this API only for control messages, and set up
            data-plane communication to pass data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
572
573

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model held by each worker and return every result.

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
585

586
587
    def _get_beam_search_lora_requests(
        self,
        lora_request: list[LoRARequest] | LoRARequest | None,
        prompts: list[TokensPrompt | TextPrompt],
    ) -> list[LoRARequest | None]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: None, a single `LoRARequest` shared by all prompts,
                or a sequence with one entry per prompt.
            prompts: The beam-search prompts.

        Returns:
            A list with exactly one (possibly None) LoRA request per prompt.

        Raises:
            ValueError: If a sequence of requests does not match the number
                of prompts.
            TypeError: If `lora_request` has an unsupported type.
        """
        if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts"
            )

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        # A correctly-sized per-prompt list used to fall through to the
        # TypeError below; pair each prompt with its own request instead.
        if isinstance(lora_request, Sequence) and not isinstance(lora_request, str):
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

602
603
    def beam_search(
        self,
604
        prompts: list[TokensPrompt | TextPrompt],
605
        params: BeamSearchParams,
606
        lora_request: list[LoRARequest] | LoRARequest | None = None,
607
        use_tqdm: bool = False,
608
        concurrency_limit: int | None = None,
609
    ) -> list[BeamSearchOutput]:
610
611
612
613
614
615
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
616
            params: The beam search parameters.
617
            lora_request: LoRA request to use for generation, if any.
618
            use_tqdm: Whether to use tqdm to display the progress bar.
619
620
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
621
        """
622
623
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
624
625
626
627
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
628
629
        length_penalty = params.length_penalty

630
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
631

632
633
634
635
636
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
637

638
639
640
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
641
642
                "Disabling progress bar."
            )
643
644
645
646
647
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

648
649
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
650
651
652
653
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
654
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
655
            return TokensPrompt(**token_prompt_kwargs)
656

657
658
659
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
660
661
662
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width, max_tokens=1, temperature=temperature
        )
663
        instances: list[BeamSearchInstance] = []
664

665
        for lora_req, prompt in zip(lora_requests, prompts):
666
667
668
669
670
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
671
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
672

673
674
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
675
676
677
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
678

679
            instances.append(
680
681
682
683
684
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
685
686
                ),
            )
687

688
        for prompt_start in range(0, len(prompts), concurrency_limit):
689
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
690
691
692

            token_iter = range(max_tokens)
            if use_tqdm:
693
694
695
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
696
697
698
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
699
700
                    "reflect instance-level progress."
                )
701
702
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
703
704
                    sum((instance.beams for instance in instances_batch), [])
                )
705
706
                pos = [0] + list(
                    itertools.accumulate(
707
708
709
                        len(instance.beams) for instance in instances_batch
                    )
                )
710
                instance_start_and_end: list[tuple[int, int]] = list(
711
712
                    zip(pos[:-1], pos[1:])
                )
713
714
715
716
717
718

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
719
720
721
722
723
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
724
725
726

                # only runs for one step
                # we don't need to use tqdm here
727
728
729
730
731
732
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
733

734
735
736
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
737
738
739
740
741
742
743
744
745
746
747
748
749
750
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
751
                                    logprobs=current_beam.logprobs + [logprobs],
752
                                    lora_request=current_beam.lora_request,
753
754
755
756
757
758
759
760
761
762
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
763
764
765
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
766
767
768
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
769
                    instance.beams = sorted_beams[:beam_width]
770
771
772
773

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
774
775
776
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
777
778
779
780
781
782
783
784
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

785
    def preprocess_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[TokensPrompt]:
        """
        Turn one or more chat conversations into tokenized prompts.

        The resulting prompts can be passed as input to the other LLM
        methods. Refer to `chat` for a complete description of the arguments.

        Returns:
            A list of `TokensPrompt` objects, one per conversation. Each holds
            the token IDs produced by applying the chat template, plus any
            multi-modal data, multi-modal UUIDs and `mm_processor_kwargs`.
        """
        # Wrap a single conversation so both input shapes share one loop.
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list)
            else [cast(list[ChatCompletionMessageParam], messages)]
        )

        tokenizer = self.get_tokenizer()
        model_config = self.model_config

        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # The explicit arguments come first; caller-supplied
        # `chat_template_kwargs` may override any of them.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
        }
        template_kwargs.update(chat_template_kwargs or {})

        tokenized_prompts: list[TokensPrompt] = []

        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if not isinstance(tokenizer, MistralTokenizer):
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already emits any special tokens, so the
                # tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered, add_special_tokens=False)
            else:
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )

            tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)

            # Attach optional fields only when present.
            if mm_data is not None:
                tokens_prompt["multi_modal_data"] = mm_data
            if mm_uuids is not None:
                tokens_prompt["multi_modal_uuids"] = mm_uuids
            if mm_processor_kwargs is not None:
                tokens_prompt["mm_processor_kwargs"] = mm_processor_kwargs

            tokenized_prompts.append(tokens_prompt)

        return tokenized_prompts

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Each conversation is rendered with the chat template and tokenized
        via `preprocess_chat`. The resulting prompts are then handed to the
        [generate][vllm.LLM.generate] method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, the default sampling parameters are used. A single
                value applies to every prompt. A list must have the same
                length as the prompts and is paired with them one by one.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template is used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Tool definitions. They are forwarded to the chat template
                and are also used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. These override the explicit template arguments
                above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        tokenized_conversations = self.preprocess_chat(
            messages=messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        # Generation itself is fully delegated to `generate`.
        return self.generate(
            tokenized_conversations,
            sampling_params=sampling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )

966
967
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict
                containing a `"data"` key is routed through the installed
                IOProcessor plugin: it is parsed and pre-processed into model
                prompts, and the outputs are post-processed.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. With an IOProcessor
                request, the plugin validates or generates them.
            truncate_prompt_tokens: If not None, this value is set as
                `truncate_prompt_tokens` on every pooling parameter object
                (kept for backwards compatibility).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run. This is required; the
                task-specific helpers (`embed`, `classify`, `reward`, ...)
                set it for you.
            tokenization_kwargs: Accepted for API compatibility but not
                currently forwarded by this method. A warning is logged when
                it is provided.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For an IOProcessor request, a single-element list wrapping the
            plugin's post-processed output.

        Raises:
            ValueError: If `pooling_task` is missing or not supported by the
                model, if the model is not run as a pooling model, or if an
                IOProcessor request is given but no plugin is installed.
        """
        error_str = (
            "pooling_task required for `LLM.encode`\n"
            "Please use one of the more specific methods or set the "
            "pooling_task when using `LLM.encode`:\n"
            "  - For embeddings, use `LLM.embed(...)` "
            'or `pooling_task="embed"`.\n'
            "  - For classification logits, use `LLM.classify(...)` "
            'or `pooling_task="classify"`.\n'
            "  - For similarity scores, use `LLM.score(...)`.\n"
            "  - For rewards, use `LLM.reward(...)` "
            'or `pooling_task="token_classify"`\n'
            "  - For token classification, "
            'use `pooling_task="token_classify"`\n'
            '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
        )

        if pooling_task is None:
            raise ValueError(error_str)

        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        # Validate the task up front, before any (potentially expensive)
        # IOProcessor parsing / pre-processing is performed.
        if pooling_task not in self.supported_tasks:
            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

        if tokenization_kwargs is not None:
            # The argument is part of the public signature but is not wired
            # through to request processing here; warn instead of silently
            # dropping it.
            logger.warning(
                "`tokenization_kwargs` is currently not supported by "
                "`LLM.encode` and will be ignored."
            )

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        if io_processor_prompt:
            assert self.io_processor is not None
            if is_list_of(pooling_params, PoolingParams):
                validated_pooling_params: list[PoolingParams] = []
                for param in as_iter(pooling_params):
                    validated_pooling_params.append(
                        self.io_processor.validate_or_generate_params(param)
                    )
                pooling_params = validated_pooling_params
            else:
                assert not isinstance(pooling_params, Sequence)
                pooling_params = self.io_processor.validate_or_generate_params(
                    pooling_params
                )
        else:
            if pooling_params is None:
                # Use default pooling params.
                pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if io_processor_prompt:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(
                model_output=model_outputs
            )

            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
                    prompt_token_ids=[],
                    finished=True,
                )
            ]
        else:
            return model_outputs
1115

1116
1117
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per prompt.

        Prompts are batched automatically within the memory budget. For the
        best throughput, pass all of your prompts in a single call.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`. If not None, it is
                set as `truncate_prompt_tokens` on every pooling parameter
                object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"embed"` task.
        """
        supports_embedding = "embed" in self.supported_tasks
        if not supports_embedding:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        pooled_outputs = self.encode(
            prompts,
            pooling_task="embed",
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
            truncate_prompt_tokens=truncate_prompt_tokens,
        )

        return list(map(EmbeddingRequestOutput.from_base, pooled_outputs))

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`. If not None, it is
                set as `truncate_prompt_tokens` on every pooling parameter
                object, consistent with `embed` and `reward`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"classify"` task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            # None (the default) leaves truncation untouched, so existing
            # callers see no change in behavior.
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1211
1212
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Compute reward outputs for each prompt.

        This is a thin wrapper around `encode` that runs the
        `"token_classify"` pooling task.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`. If not None, it is
                set as `truncate_prompt_tokens` on every pooling parameter
                object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        reward_task = "token_classify"
        return self.encode(
            prompts,
            pooling_task=reward_task,
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )

1249
1250
1251
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str | TextPrompt | TokensPrompt],
        text_2: list[str | TextPrompt | TokensPrompt],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by cosine similarity of their embeddings.

        Both sides are embedded in a single `encode` call (`text_1` first,
        then `text_2`), and the results are split back apart. When `text_1`
        yields a single embedding, it is paired with every `text_2`
        embedding.
        """
        split_at = len(text_1)

        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            pooling_task="embed",
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
            truncate_prompt_tokens=truncate_prompt_tokens,
        )

        left: list[PoolingRequestOutput] = embeddings[:split_at]
        right: list[PoolingRequestOutput] = embeddings[split_at:]

        # 1 -> N case: broadcast the single query embedding.
        if len(left) == 1:
            left = left * len(right)

        scores = _cosine_similarity(tokenizer=tokenizer, embed_1=left, embed_2=right)

        validated = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(output) for output in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: list[str] | list[ScoreContentPartParam],
        data_2: list[str] | list[ScoreContentPartParam],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score `(data_1[i], data_2[i])` pairs with a cross-encoder model.

        Each pair is turned into one engine prompt via `get_score_prompt`.
        If a prompt carries `token_type_ids`, they are removed from the
        prompt and attached, in compressed form, to a per-prompt clone of the
        pooling params. Otherwise the shared pooling params object is used.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        model_config = self.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        # 1 -> N case: repeat the single query so it pairs with every item.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        base_params = (
            pooling_params
            if pooling_params is not None
            else PoolingParams(task="score")
        )
        base_params.verify("score", model_config)

        # `_validate_truncation_size` receives this dict and it is then
        # forwarded to every `get_score_prompt` call below.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
        )

        engine_prompts: list[PromptType] = []
        per_prompt_params: list[PoolingParams] = []

        for query, document in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=query,
                data_2=document,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            token_type_ids = engine_prompt.pop("token_type_ids", None)
            if not token_type_ids:
                per_prompt_params.append(base_params)
            else:
                prompt_params = base_params.clone()
                prompt_params.extra_kwargs = {
                    "compressed_token_type_ids": compress_token_type_ids(
                        token_type_ids
                    )
                }
                per_prompt_params.append(prompt_params)

            engine_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=engine_prompts,
            params=per_prompt_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        raw_outputs = self._run_engine(use_tqdm=use_tqdm)
        validated = self.engine_class.validate_outputs(
            raw_outputs, PoolingRequestOutput
        )

        return [ScoringRequestOutput.from_base(output) for output in validated]

1346
1347
    def score(
        self,
        data_1: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Optional token limit forwarded unchanged
                to the cross-encoding / embedding scoring helper; presumably
                the number of prompt tokens kept after truncation.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a pooling model, supports
                neither the `embed` nor the `classify` task, is a
                cross-encoder whose `num_labels != 1`, or (for non-multi-modal
                models) receives a `ScoreMultiModalParam`, a multi-modal
                prompt, or a prompt that cannot be turned into text.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        if (
            model_config.is_cross_encoder
            and getattr(model_config.hf_config, "num_labels", 0) != 1
        ):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(
                data: SingletonPrompt
                | Sequence[SingletonPrompt]
                | ScoreMultiModalParam,
            ):
                # A dict with a "content" key is a ScoreMultiModalParam,
                # which only multi-modal models can consume.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError(
                        "ScoreMultiModalParam is not supported "
                        f"for {model_config.architecture}"
                    )

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt) -> str:
                # Normalize a single prompt to plain text; token-ID prompts
                # are decoded back to text (see tokenizer note above).
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError(
                            "Multi-modal prompt is not supported for scoring"
                        )
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"]
                        )
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check instead of `assert`: this validates user
                # input and must not disappear under `python -O`.
                if not isinstance(prompt, str):
                    raise ValueError(
                        "Scoring expects text prompts, got "
                        f"{type(prompt).__name__}"
                    )
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Unwrap ScoreMultiModalParam into its content list; wrap a bare
        # string into a one-element list.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )
1495

1496
1497
1498
1499
1500
1501
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling.

        Delegates to `self.llm_engine.start_profile()`; any value the engine
        returns is discarded.
        """
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling.

        Delegates to `self.llm_engine.stop_profile()`; any value the engine
        returns is discarded.
        """
        engine = self.llm_engine
        engine.stop_profile()

1502
1503
    def reset_prefix_cache(self, device: Device | None = None) -> None:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, forwarded positionally and
                unchanged to `self.llm_engine.reset_prefix_cache`.

        Any value returned by the engine call is discarded.
        """
        engine = self.llm_engine
        engine.reset_prefix_cache(device)
1504

1505
1506
1507
1508
1509
1510
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset first, then the engine is told to sleep.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # Presumably done because cached prefixes would not survive the
        # KV cache being discarded; the return value is not checked.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1527
    def wake_up(self, tags: list[str] | None = None):
1528
        """
1529
1530
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1531

1532
        Args:
1533
1534
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1535
                `("weights", "kv_cache")`. If None, all memory is reallocated.
1536
                wake_up should be called with all tags (or None) before the
1537
1538
1539
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1540

1541
1542
1543
1544
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Delegates directly to `self.llm_engine.get_metrics()`.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1553
1554
    def _validate_and_add_requests(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        params: SamplingParams
        | Sequence[SamplingParams]
        | PoolingParams
        | Sequence[PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None,
        priority: list[int] | None = None,
    ) -> None:
        """Validate per-request arguments and add every prompt to the engine.

        `params` and `lora_request` may each be a single value shared by all
        prompts or a sequence with one entry per prompt; `priority`, when
        given, must have one entry per prompt.

        Raises:
            ValueError: If a per-request sequence length differs from the
                number of prompts.

        If adding any request fails, every request already added by this call
        is aborted before the exception is re-raised.
        """
        # A bare prompt (text or dict) is treated as a one-element batch.
        if isinstance(prompts, (str, dict)):
            prompts = [prompts]  # type: ignore[list-item]

        num_requests = len(prompts)
        params_is_seq = isinstance(params, Sequence)
        lora_is_seq = isinstance(lora_request, Sequence)

        if params_is_seq and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params must be the same.")
        if lora_is_seq and len(lora_request) != num_requests:
            raise ValueError(
                "The lengths of prompts and lora_request must be the same."
            )
        if priority is not None and len(priority) != num_requests:
            raise ValueError(
                "The lengths of prompts "
                f"({num_requests}) and priority ({len(priority)}) "
                "must be the same."
            )

        # We only care about the final output, so force FINAL_ONLY on every
        # SamplingParams (PoolingParams are left untouched).
        all_params = params if params_is_seq else (params,)
        for candidate in all_params:
            if isinstance(candidate, SamplingParams):
                candidate.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine, optionally behind a progress bar.
        prompt_iter = prompts
        if use_tqdm:
            make_bar = use_tqdm if callable(use_tqdm) else tqdm
            prompt_iter = make_bar(prompt_iter, desc="Adding requests")

        added_request_ids: list[str] = []
        try:
            for idx, prompt in enumerate(prompt_iter):
                if isinstance(prompt, dict):
                    self._validate_mm_data_and_uuids(
                        prompt.get("multi_modal_data"),
                        prompt.get("multi_modal_uuids"),
                    )
                request_params = params[idx] if params_is_seq else params
                request_lora = lora_request[idx] if lora_is_seq else lora_request
                request_priority = priority[idx] if priority else 0
                new_id = self._add_request(
                    prompt,
                    request_params,
                    lora_request=request_lora,
                    priority=request_priority,
                )
                added_request_ids.append(new_id)
        except Exception:
            # Roll back the partially-added batch so no orphaned requests
            # stay queued in the engine.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids)
            raise
1615

1616
    def _validate_mm_data_and_uuids(
1617
        self,
1618
1619
        multi_modal_data: Any | None,  # MultiModalDataDict
        multi_modal_uuids: Any | None,  # MultiModalUUIDDict
1620
1621
1622
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1623
        then its corresponding UUID must be set.
1624
1625
1626
1627
1628
1629
1630
1631
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
1632
1633
1634
1635
1636
1637
1638
1639
                        if (
                            multi_modal_uuids is None
                            or modality not in multi_modal_uuids
                            or multi_modal_uuids[  # noqa: E501
                                modality
                            ]
                            is None
                        ):
1640
1641
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
1642
1643
                                f"but UUID is not provided"
                            )
1644
                        else:
1645
1646
1647
1648
                            if (
                                len(multi_modal_uuids[modality]) <= i
                                or multi_modal_uuids[modality][i] is None
                            ):
1649
1650
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
1651
1652
                                    f"but UUID is not provided"
                                )
1653
            else:
1654
1655
1656
1657
1658
1659
1660
1661
1662
                if data is None and (
                    multi_modal_uuids is None
                    or modality not in multi_modal_uuids
                    or multi_modal_uuids[modality] is None
                ):
                    raise ValueError(
                        f"Multi-modal data for {modality} is None"
                        f" but UUID is not provided"
                    )
1663

1664
1665
1666
1667
    def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to turn one prompt into an engine request.

        Returns:
            A pair of the processed `EngineCoreRequest` and the tokenization
            kwargs dict that was passed to `_validate_truncation_size` and
            then to the processor.
        """
        tokenization_kwargs: dict[str, Any] = {}
        # Check the requested truncation against the model's max length;
        # presumably this records truncation settings in
        # `tokenization_kwargs`, which is forwarded below.
        max_len = self.model_config.max_model_len
        _validate_truncation_size(
            max_len, params.truncate_prompt_tokens, tokenization_kwargs
        )

        processor = self.processor
        engine_request = processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
        return engine_request, tokenization_kwargs

1691
    def _add_request(
        self,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
    ) -> str:
        """Process one prompt and submit it to the engine.

        A fresh request ID is taken from `self.request_counter`, the prompt is
        processed via `_process_inputs`, and the result is handed to
        `self.llm_engine.add_request`.

        Returns:
            The string request ID assigned to this request.
        """
        prompt_text, _, _ = get_prompt_components(prompt)
        request_id = str(next(self.request_counter))

        # Both calls receive the same LoRA / priority settings.
        shared_kwargs = dict(lora_request=lora_request, priority=priority)

        engine_request, tokenization_kwargs = self._process_inputs(
            request_id, prompt, params, **shared_kwargs
        )
        self.llm_engine.add_request(
            request_id,
            engine_request,
            params,
            tokenization_kwargs=tokenization_kwargs,
            prompt_text=prompt_text,
            **shared_kwargs,
        )
        return request_id
1719

1720
    def _run_engine(
        self, *, use_tqdm: bool | Callable[..., tqdm] = True
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If truthy, show a progress bar. A callable is used as
                the bar factory; otherwise `tqdm` itself is used.

        Returns:
            Every finished output, sorted by numeric request ID.
        """
        # Initialize tqdm.
        pbar = None
        num_requests = 0
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs
                        )
                        # With a coarse clock `elapsed` can still be 0.0;
                        # skip the speed estimate then instead of dividing
                        # by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s"
                            )
                        pbar.update(n)
                    else:
                        pbar.update(1)
                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Close the bar even if `step()` raises, so it does not linger.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))