llm.py 72.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
6
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
    CompilationConfig,
22
    PoolerConfig,
23
24
25
    StructuredOutputsConfig,
    is_init_field,
)
26
from vllm.config.model import (
27
28
    ConvertOption,
    HfOverrides,
29
    ModelDType,
30
    RunnerOption,
31
    TokenizerMode,
32
)
33
from vllm.engine.arg_utils import EngineArgs
34
from vllm.engine.protocol import Device
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
    resolve_chat_template_content_format,
)
from vllm.entrypoints.score_utils import (
    ScoreContentPartParam,
    ScoreMultiModalParam,
    _cosine_similarity,
    _validate_score_input_lens,
    compress_token_type_ids,
    get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
from vllm.inputs import (
    DataPrompt,
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
59
from vllm.inputs.parse import get_prompt_components
60
from vllm.logger import init_logger
61
from vllm.lora.request import LoRARequest
62
from vllm.model_executor.layers.quantization import QuantizationMethods
63
64
65
66
67
68
69
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
70
from vllm.platforms import current_platform
71
from vllm.pooling_params import PoolingParams
72
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
73
from vllm.tasks import PoolingTask
74
75
76
77
78
from vllm.transformers_utils.tokenizer import (
    AnyTokenizer,
    MistralTokenizer,
    get_cached_tokenizer,
)
yhu422's avatar
yhu422 committed
79
from vllm.usage.usage_lib import UsageContext
80
from vllm.utils.collection_utils import as_iter, is_list_of
81
from vllm.utils.counter import Counter
82
from vllm.v1.engine import EngineCoreRequest
83
from vllm.v1.engine.llm_engine import LLMEngine
84
from vllm.v1.sample.logits_processor import LogitsProcessor
85

86
87
88
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

89
90
logger = init_logger(__name__)

# Result type of callables executed on workers (see `collective_rpc` and
# `apply_model`); falls back to `Any` when it cannot be inferred.
_R = TypeVar("_R", default=Any)

93
94

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
95
96
97
98
99
100
101
102
103
104
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
105
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
106
107
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
108
109
110
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
111
112
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
113
114
115
116
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
117
        allowed_media_domains: If set, only media URLs that belong to this
118
            domain can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
119
120
121
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
122
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
123
124
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
125
        quantization: The method used to quantize the model weights. Currently,
126
            we support "awq", "gptq", and "fp8" (experimental).
127
128
129
130
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
131
132
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
133
134
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
135
136
137
138
139
140
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
141
142
143
144
145
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vllm can automatically infer the kv cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the kv cache memory size. kv_cache_memory_bytes
            allows more fine-grain control of how much memory gets used when
146
            compared with using gpu_memory_utilization. Note that
147
148
            kv_cache_memory_bytes (when not-None) ignores
            gpu_memory_utilization
149
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
150
151
152
153
154
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Noting that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
155
156
157
158
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
159
160
161
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
162
163
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
164
        hf_token: The token to use as HTTP bearer authorization for remote files
165
            . If `True`, will use the token generated when running
166
            `huggingface-cli login` (stored in `~/.huggingface`).
167
168
169
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
170
171
172
173
174
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
175
176
177
178
179
        pooler_config: Initialize non-default pooling config for the pooling
            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
        override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
            argument is deprecated and will be removed in v0.12.0 or v1.0.0,
            whichever is sooner.
180
        compilation_config: Either an integer or a dictionary. If it is an
181
            integer, it is used as the mode of compilation optimization. If it
182
            is a dictionary, it can specify the full compilation configuration.
183
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
184

185
186
    Note:
        This class is intended to be used for offline inference. For online
187
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
188
    """
189
190
191
192

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int | None = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        override_pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes the convenience forms accepted by `LLM` (dicts for nested
        configs, an int for the compilation mode, a class object for
        `worker_cls`) into what `EngineArgs` expects, creates the engine, and
        caches a few engine attributes on `self`.

        Raises:
            ValueError: If `kv_transfer_config` is a dict that fails
                validation, or if `data_parallel_size > 1` is requested outside
                the external-launcher / TPU paths.
        """
        # Stats logging is off by default for offline use unless requested.
        kwargs.setdefault("disable_log_stats", True)

        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            # A class object (as opposed to a qualified-name string) is
            # serialized with cloudpickle to avoid pickling issues.
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        raw_config_dict = kwargs.get("kv_transfer_config")
        if isinstance(raw_config_dict, dict):
            from vllm.config.kv_transfer import KVTransferConfig

            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        # An int selects the compilation mode; a dict is filtered down to the
        # fields CompilationConfig accepts in its constructor.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(mode=compilation_config)
        elif isinstance(compilation_config, dict):
            compilation_config_instance = CompilationConfig(
                **self._init_field_kwargs(CompilationConfig, compilation_config)
            )
        else:
            compilation_config_instance = compilation_config

        if structured_outputs_config is None:
            structured_outputs_instance = StructuredOutputsConfig()
        elif isinstance(structured_outputs_config, dict):
            structured_outputs_instance = StructuredOutputsConfig(
                **self._init_field_kwargs(
                    StructuredOutputsConfig, structured_outputs_config
                )
            )
        else:
            structured_outputs_instance = structured_outputs_config

        # Reject single-process data parallelism (it may hang); only the
        # external launcher and TPU platforms are allowed through.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            override_pooler_config=override_pooler_config,
            structured_outputs_config=structured_outputs_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        # Create the (V1) engine.
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Populated lazily by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        self.model_config = self.llm_engine.model_config
        self.processor = self.llm_engine.processor
        self.io_processor = self.llm_engine.io_processor

    @staticmethod
    def _init_field_kwargs(
        config_cls: type, raw: dict[str, Any]
    ) -> dict[str, Any]:
        """Return the entries of `raw` that are init fields of `config_cls`.

        Keys that are not init fields are dropped without error, which is how
        user-supplied config dicts have always been treated by `LLM`.
        """
        return {k: v for k, v in raw.items() if is_init_field(config_cls, k)}
356

357
358
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer obtained from the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
359

360
    @deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the engine (deprecated).

        A tokenizer whose class name starts with "Cached" is installed as-is;
        any other tokenizer is wrapped with `get_cached_tokenizer` first.
        """
        # The cached tokenizer class is created dynamically, so its name prefix
        # is the only available test. A user-defined tokenizer whose class name
        # happens to start with "Cached" will be misjudged as already cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        self.llm_engine.tokenizer = (
            tokenizer if looks_cached else get_cached_tokenizer(tokenizer)
        )
369

370
371
372
373
    def reset_mm_cache(self) -> None:
        """Clear the processor's multi-modal cache, then the engine's."""
        processor, engine = self.processor, self.llm_engine
        processor.clear_mm_cache()
        engine.reset_mm_cache()

374
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default sampling parameters for this model.

        The model config's diff sampling params are fetched on first use and
        memoized in `self.default_sampling_params`. Non-empty overrides are
        passed to `SamplingParams.from_optional`; otherwise a plain
        `SamplingParams()` is returned.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

381
382
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, the model's default sampling parameters are used (see
                `get_default_sampling_params`).
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any: either a
                single request or a list paired one by one with the prompts.
                For multimodal models with default modality LoRAs configured,
                a prompt without an explicit request may be assigned the LoRA
                registered for its modality.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a generative model
                (its `runner_type` is not `"generate"`).
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
448

449
    def _get_modality_specific_lora_reqs(
        self,
        prompts: PromptType | Sequence[PromptType],
        lora_request: list[LoRARequest] | LoRARequest | None,
    ) -> list[LoRARequest | None] | LoRARequest | None:
        """Attach default modality-specific LoRAs to multimodal prompts.

        When the engine's LoRA config defines `default_mm_loras` and the model
        is multimodal, every prompt is resolved via
        `_resolve_single_prompt_mm_lora` and a list with one entry per prompt
        is returned. Otherwise `lora_request` is returned unchanged.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from the number of prompts.
        """
        # Grab the lora config off the vllm config on the engine.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (
            lora_config is None
            or not self.model_config.is_multimodal_model
            or lora_config.default_mm_loras is None
        ):
            return lora_request

        # Wrap a single prompt. `str` is itself a Sequence, so it must be
        # checked explicitly; otherwise every character would be treated as
        # a separate prompt.
        if isinstance(prompts, str) or not isinstance(prompts, Sequence):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            # zip() below would silently drop extra entries, so reject a
            # mismatched list up front.
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "The lengths of prompts and lora_request must be the same."
                )
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(prompts)

        default_mm_loras = lora_config.default_mm_loras
        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                default_mm_loras,
            )
            for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

485
486
487
    def _resolve_single_prompt_mm_lora(
        self,
        prompt: PromptType,
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
    ) -> LoRARequest | None:
        """Choose the LoRA request for a single prompt.

        If exactly one modality in the prompt's `multi_modal_data` has a
        registered default LoRA and no explicit `lora_request` was given, a
        `LoRARequest` for that modality's LoRA is built and returned. An
        explicit `lora_request` always wins (with a warning if its ID differs
        from the modality LoRA's ID). In every other case `lora_request` is
        returned unchanged.
        """
        if (
            not default_mm_loras
            or not isinstance(prompt, dict)
            or not (mm_data := prompt.get("multi_modal_data") or {})
        ):
            return lora_request

        intersection = set(
            mm_data.keys()  # type: ignore
        ).intersection(default_mm_loras.keys())

        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped",
                intersection,
            )
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request is always used; only warn when its
        # ID conflicts with the modality's default LoRA.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt"
                )
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

538
539
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Invoke a method on every worker and gather what each one returns.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to every worker for execution.

                A callable receives the worker object as an extra leading
                `self` argument, followed by `args` and `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait. A
                [`TimeoutError`][] is raised when it elapses; `None` waits
                indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list holding one result per worker.

        Note:
            Use this API for control messages only; bulk data should travel
            over a separately established data-plane channel.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
569
570

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model held by each worker and collect one result
        per worker.

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
582

583
584
    def _get_beam_search_lora_requests(
        self,
        lora_request: list[LoRARequest] | LoRARequest | None,
        prompts: list[TokensPrompt | TextPrompt],
    ) -> list[LoRARequest | None]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: `None`, a single request applied to every prompt,
                or a sequence holding exactly one entry per prompt.
            prompts: The prompts the requests are paired with.

        Returns:
            A list with one (possibly `None`) LoRA request per prompt.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from `len(prompts)`.
            TypeError: If `lora_request` has any other type.
        """
        if isinstance(lora_request, Sequence) and not isinstance(lora_request, str):
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            # A correctly sized sequence already pairs one entry with each
            # prompt, as the `list[LoRARequest]` annotation promises.
            return list(lora_request)

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

599
600
    def beam_search(
        self,
601
        prompts: list[TokensPrompt | TextPrompt],
602
        params: BeamSearchParams,
603
        lora_request: list[LoRARequest] | LoRARequest | None = None,
604
        use_tqdm: bool = False,
605
        concurrency_limit: int | None = None,
606
    ) -> list[BeamSearchOutput]:
607
608
609
610
611
612
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
613
            params: The beam search parameters.
614
            lora_request: LoRA request to use for generation, if any.
615
            use_tqdm: Whether to use tqdm to display the progress bar.
616
617
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
618
        """
619
620
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
621
622
623
624
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
625
626
        length_penalty = params.length_penalty

627
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
628

629
630
631
632
633
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
634

635
636
637
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
638
639
                "Disabling progress bar."
            )
640
641
642
643
644
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

645
646
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
647
648
649
650
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
651
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
652
            return TokensPrompt(**token_prompt_kwargs)
653

654
655
656
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
657
658
659
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width, max_tokens=1, temperature=temperature
        )
660
        instances: list[BeamSearchInstance] = []
661

662
        for lora_req, prompt in zip(lora_requests, prompts):
663
664
665
666
667
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
668
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
669

670
671
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
672
673
674
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
675

676
            instances.append(
677
678
679
680
681
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
682
683
                ),
            )
684

685
        for prompt_start in range(0, len(prompts), concurrency_limit):
686
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
687
688
689

            token_iter = range(max_tokens)
            if use_tqdm:
690
691
692
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
693
694
695
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
696
697
                    "reflect instance-level progress."
                )
698
699
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
700
701
                    sum((instance.beams for instance in instances_batch), [])
                )
702
703
                pos = [0] + list(
                    itertools.accumulate(
704
705
706
                        len(instance.beams) for instance in instances_batch
                    )
                )
707
                instance_start_and_end: list[tuple[int, int]] = list(
708
709
                    zip(pos[:-1], pos[1:])
                )
710
711
712
713
714
715

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
716
717
718
719
720
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
721
722
723

                # only runs for one step
                # we don't need to use tqdm here
724
725
726
727
728
729
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
730

731
732
733
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
734
735
736
737
738
739
740
741
742
743
744
745
746
747
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
748
                                    logprobs=current_beam.logprobs + [logprobs],
749
                                    lora_request=current_beam.lora_request,
750
751
752
753
754
755
756
757
758
759
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
760
761
762
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
763
764
765
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
766
                    instance.beams = sorted_beams[:beam_width]
767
768
769
770

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
771
772
773
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
774
775
776
777
778
779
780
781
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

782
    def preprocess_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[TokensPrompt]:
        """
        Render one or more chat conversations into token prompts.

        Each conversation is rendered with the chat template (or with the
        Mistral-specific renderer when the tokenizer is a `MistralTokenizer`)
        and tokenized. Multi-modal data and UUIDs extracted from the messages,
        as well as `mm_processor_kwargs`, are attached to the resulting
        prompt when present. The output can be passed to the other LLM
        methods.

        Refer to `chat` for a complete description of the arguments.

        Returns:
            A list of `TokensPrompt` objects, one per conversation, containing
            the tokenized prompt after chat template interpolation and the
            pre-processed multi-modal inputs.
        """
        # A single conversation (list of messages) is wrapped into a batch of
        # one so that both call shapes share the loop below.
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list)
            else [cast(list[ChatCompletionMessageParam], messages)]
        )

        tokenizer = self.get_tokenizer()
        model_config = self.model_config

        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Entries from the user-supplied chat_template_kwargs win over the
        # explicit arguments because they are unpacked last.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        tokens_prompts: list[TokensPrompt] = []
        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already emits the special tokens, so the
                # tokenizer must not insert them a second time.
                token_ids = tokenizer.encode(rendered, add_special_tokens=False)

            tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
            # Optional fields are only set when they carry a value.
            if mm_data is not None:
                tokens_prompt["multi_modal_data"] = mm_data
            if mm_uuids is not None:
                tokens_prompt["multi_modal_uuids"] = mm_uuids
            if mm_processor_kwargs is not None:
                tokens_prompt["mm_processor_kwargs"] = mm_processor_kwargs

            tokens_prompts.append(tokens_prompt)

        return tokens_prompts

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        The conversations are first turned into token prompts by
        `preprocess_chat` (chat template interpolation plus tokenization), and
        those prompts are then handed to the [generate][vllm.LLM.generate]
        method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, asks the chat template to append
                a generation prompt.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Tool definitions passed through to `preprocess_chat`, which
                forwards them to the chat template.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        tokenized_conversations = self.preprocess_chat(
            messages=messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        # The prompts are already tokenized; generation itself is delegated.
        return self.generate(
            tokenized_conversations,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

963
964
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A
                `DataPrompt` (a dict containing a `"data"` key) is instead
                validated and pre-processed by the installed IOProcessor
                plugin, and the plugin's post-processed result is returned as
                a single `PoolingRequestOutput`.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. May be a single value or
                a sequence of values.
            truncate_prompt_tokens: If not None, assigned to the
                `truncate_prompt_tokens` field of every pooling params object
                (kept for backwards compatibility).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run. Required; must be one of
                the tasks supported by the model.
            tokenization_kwargs: overrides tokenization_kwargs set in
                pooling_params

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If `pooling_task` is missing or unsupported, the model
                is not a pooling model, or a `DataPrompt` is passed without an
                IOProcessor plugin installed.
        """
        # NOTE(review): `tokenization_kwargs` is accepted but not read anywhere
        # in this method — confirm whether it should be forwarded.
        if pooling_task is None:
            error_str = (
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )
            raise ValueError(error_str)

        model_config = self.model_config
        if model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        io_processor_prompt = isinstance(prompts, dict) and "data" in prompts
        if io_processor_prompt:
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

            # Let the plugin validate (or synthesize) the pooling params.
            # Any sequence is accepted, matching the annotated parameter type,
            # not only lists.
            if isinstance(pooling_params, Sequence):
                pooling_params = [
                    self.io_processor.validate_or_generate_params(param)
                    for param in pooling_params
                ]
            else:
                pooling_params = self.io_processor.validate_or_generate_params(
                    pooling_params
                )
        elif pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        if pooling_task not in self.supported_tasks:
            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if not io_processor_prompt:
            return model_outputs

        # get the post-processed model outputs
        assert self.io_processor is not None
        processed_outputs = self.io_processor.post_process(
            model_output=model_outputs
        )

        return [
            PoolingRequestOutput[Any](
                request_id="",
                outputs=processed_outputs,
                num_cached_tokens=getattr(
                    processed_outputs, "num_cached_tokens", 0
                ),
                prompt_token_ids=[],
                finished=True,
            )
        ]
1112

1113
1114
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute an embedding vector for every prompt.

        Prompts are batched automatically with respect to the memory
        constraint, so passing all prompts in a single list to one call gives
        the best performance.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`, which assigns it to
                every pooling params object when not None.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        pooled = self.encode(
            prompts,
            pooling_task="embed",
            lora_request=lora_request,
            pooling_params=pooling_params,
            use_tqdm=use_tqdm,
            truncate_prompt_tokens=truncate_prompt_tokens,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`, which assigns it to
                every pooling params object when not None. Defaults to None
                (no change), matching `embed` and `reward`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1208
1209
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Rewards are computed by running `encode` with the "token_classify"
        pooling task.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`, which assigns it to
                every pooling params object when not None.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        reward_task: PoolingTask = "token_classify"
        return self.encode(
            prompts,
            pooling_task=reward_task,
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )

1246
1247
1248
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str | TextPrompt | TokensPrompt],
        text_2: list[str | TextPrompt | TokensPrompt],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score input pairs by the cosine similarity of their embeddings.

        Both inputs are embedded in a single `encode` call with the "embed"
        task. The first `len(text_1)` outputs are taken as the `text_1` side
        and the remainder as the `text_2` side. When the `text_1` side holds a
        single output, it is repeated to pair with every `text_2` output.
        """
        boundary = len(text_1)
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            pooling_task="embed",
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
            truncate_prompt_tokens=truncate_prompt_tokens,
        )

        left: list[PoolingRequestOutput] = embeddings[:boundary]
        right: list[PoolingRequestOutput] = embeddings[boundary:]

        # 1 -> N scoring: broadcast the lone left-hand embedding.
        if len(left) == 1:
            left = left * len(right)

        similarity = _cosine_similarity(
            tokenizer=tokenizer, embed_1=left, embed_2=right
        )
        validated = self.engine_class.validate_outputs(
            similarity, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(entry) for entry in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: list[str] | list[ScoreContentPartParam],
        data_2: list[str] | list[ScoreContentPartParam],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score input pairs by running each pair through the model jointly.

        Each `(data_1[i], data_2[i])` pair is turned into one engine prompt by
        `get_score_prompt`, and the engine is run with the "score" pooling
        task. A single `data_1` entry is repeated to pair with every `data_2`
        entry.

        Raises:
            ValueError: If the tokenizer is a `MistralTokenizer`.
        """
        model_config = self.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        # 1 -> N scoring: broadcast the lone first input.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        pooling_params.verify("score", model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
        )

        engine_prompts: list[PromptType] = []
        params_per_prompt: list[PoolingParams] = []
        for first, second in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=first,
                data_2=second,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            # token_type_ids are moved out of the prompt and into a per-prompt
            # copy of the pooling params (in compressed form).
            token_type_ids = engine_prompt.pop("token_type_ids", None)
            if token_type_ids:
                prompt_params = pooling_params.clone()
                prompt_params.extra_kwargs = {
                    "compressed_token_type_ids": compress_token_type_ids(
                        token_type_ids
                    )
                }
            else:
                prompt_params = pooling_params
            params_per_prompt.append(prompt_params)

            engine_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=engine_prompts,
            params=params_per_prompt,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        raw_outputs = self._run_engine(use_tqdm=use_tqdm)
        validated = self.engine_class.validate_outputs(
            raw_outputs, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(entry) for entry in validated]

1343
1344
    def score(
        self,
        data_1: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Optional prompt-truncation setting,
                passed through unchanged to the underlying scoring helper.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for scoring, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a pooling model, supports
                neither the `embed` nor the `classify` task, is a
                cross-encoder whose `num_labels != 1`, or if a text-only model
                receives multi-modal input or a prompt that cannot be
                converted to a string.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        if (
            model_config.is_cross_encoder
            and getattr(model_config.hf_config, "num_labels", 0) != 1
        ):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(
                data: SingletonPrompt
                | Sequence[SingletonPrompt]
                | ScoreMultiModalParam,
            ):
                # A dict carrying "content" is treated as ScoreMultiModalParam,
                # which a text-only model cannot consume.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError(
                        "ScoreMultiModalParam is not supported "
                        f"for {model_config.architecture}"
                    )

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt):
                """Reduce a text/token prompt to plain text (see tokenizer note)."""
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError(
                            "Multi-modal prompt is not supported for scoring"
                        )
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"]
                        )
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check instead of `assert`: assertions are stripped
                # under `python -O`, which would let a non-string reach the
                # tokenizer.
                if not isinstance(prompt, str):
                    raise ValueError(
                        "Score prompts must be strings, TextPrompt or "
                        f"TokensPrompt dicts; got {type(prompt).__name__}"
                    )
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Multi-modal path: unwrap ScoreMultiModalParam into its content list
        # and promote a bare string to a one-element list.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )
1492

1493
1494
1495
1496
1497
1498
    def start_profile(self) -> None:
        """Delegate to the engine's `start_profile()`."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Delegate to the engine's `stop_profile()`."""
        engine = self.llm_engine
        engine.stop_profile()

1499
1500
    def reset_prefix_cache(self, device: Device | None = None) -> None:
        """Ask the engine to reset its prefix cache, forwarding `device` as-is."""
        engine = self.llm_engine
        engine.reset_prefix_cache(device)
1501

1502
1503
1504
1505
1506
1507
    def sleep(self, level: int = 1):
        """
        Put the engine into sleep mode. The engine should not process any
        requests while asleep: the caller must guarantee that no requests are
        in flight from now until `wake_up` is called.

        The prefix cache is reset first; then the request is forwarded to the
        engine's `sleep`.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # Drop cached prefixes before the engine releases its memory.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1524
    def wake_up(self, tags: list[str] | None = None):
1525
        """
1526
1527
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1528

1529
        Args:
1530
1531
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1532
                `("weights", "kv_cache")`. If None, all memory is reallocated.
1533
                wake_up should be called with all tags (or None) before the
1534
1535
1536
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1537

1538
1539
1540
1541
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        The call is forwarded unchanged to the engine's `get_metrics()`.

        Returns:
            A list of `Metric` objects capturing the current state
            of all aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1550
1551
    def _validate_and_add_requests(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        params: SamplingParams
        | Sequence[SamplingParams]
        | PoolingParams
        | Sequence[PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None,
        priority: list[int] | None = None,
    ) -> None:
        """Validate a batch of prompts and submit each one to the engine.

        `params` and `lora_request` may each be a single value shared by all
        prompts or a sequence with one entry per prompt; `priority`, when
        given, must also have one entry per prompt.

        Raises:
            ValueError: If any per-prompt sequence length does not match the
                number of prompts, or if multi-modal validation fails.
            Exception: Any error raised while adding a request is re-raised
                after aborting the requests that were already added.
        """
        if isinstance(prompts, (str, dict)):
            # A lone prompt is treated as a batch of one.
            prompts = [prompts]  # type: ignore[list-item]

        num_requests = len(prompts)
        params_per_prompt = isinstance(params, Sequence)
        lora_per_prompt = isinstance(lora_request, Sequence)

        if params_per_prompt and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params must be the same.")
        if lora_per_prompt and len(lora_request) != num_requests:
            raise ValueError(
                "The lengths of prompts and lora_request must be the same."
            )
        if priority is not None and len(priority) != num_requests:
            raise ValueError(
                "The lengths of prompts "
                f"({num_requests}) and priority ({len(priority)}) "
                "must be the same."
            )

        # Callers only consume the final output, so mark every
        # SamplingParams accordingly (mutates the caller's objects).
        for sp in params if params_per_prompt else (params,):
            if isinstance(sp, SamplingParams):
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        source = prompts
        if use_tqdm:
            make_bar = use_tqdm if callable(use_tqdm) else tqdm
            source = make_bar(source, desc="Adding requests")

        submitted: list[str] = []
        try:
            for idx, prompt in enumerate(source):
                if isinstance(prompt, dict):
                    self._validate_mm_data_and_uuids(
                        prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids")
                    )
                submitted.append(
                    self._add_request(
                        prompt,
                        params[idx] if params_per_prompt else params,
                        lora_request=(
                            lora_request[idx] if lora_per_prompt else lora_request
                        ),
                        priority=priority[idx] if priority else 0,
                    )
                )
        except Exception:
            # Roll back: abort what was already submitted, then propagate.
            if submitted:
                self.llm_engine.abort_request(submitted)
            raise
1612

1613
    def _validate_mm_data_and_uuids(
1614
        self,
1615
1616
        multi_modal_data: Any | None,  # MultiModalDataDict
        multi_modal_uuids: Any | None,  # MultiModalUUIDDict
1617
1618
1619
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1620
        then its corresponding UUID must be set.
1621
1622
1623
1624
1625
1626
1627
1628
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
1629
1630
1631
1632
1633
1634
1635
1636
                        if (
                            multi_modal_uuids is None
                            or modality not in multi_modal_uuids
                            or multi_modal_uuids[  # noqa: E501
                                modality
                            ]
                            is None
                        ):
1637
1638
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
1639
1640
                                f"but UUID is not provided"
                            )
1641
                        else:
1642
1643
1644
1645
                            if (
                                len(multi_modal_uuids[modality]) <= i
                                or multi_modal_uuids[modality][i] is None
                            ):
1646
1647
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
1648
1649
                                    f"but UUID is not provided"
                                )
1650
            else:
1651
1652
1653
1654
1655
1656
1657
1658
1659
                if data is None and (
                    multi_modal_uuids is None
                    or modality not in multi_modal_uuids
                    or multi_modal_uuids[modality] is None
                ):
                    raise ValueError(
                        f"Multi-modal data for {modality} is None"
                        f" but UUID is not provided"
                    )
1660

1661
1662
1663
1664
    def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Build an engine request for one prompt using `self.processor`.

        Returns:
            The processed engine request, together with the tokenization
            kwargs dict that was handed to both `_validate_truncation_size`
            and the processor.
        """
        tok_kwargs: dict[str, Any] = {}
        # NOTE(review): the dict is passed in and later returned, so
        # _validate_truncation_size presumably fills it in place — confirm.
        _validate_truncation_size(
            self.model_config.max_model_len,
            params.truncate_prompt_tokens,
            tok_kwargs,
        )

        request = self.processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tok_kwargs,
            priority=priority,
        )
        return request, tok_kwargs

1688
    def _add_request(
        self,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
    ) -> str:
        """Process a single prompt, hand it to the engine, and return its ID.

        The ID is the next value of `self.request_counter`, rendered as a
        string.
        """
        text = get_prompt_components(prompt)[0]
        rid = str(next(self.request_counter))

        processed, tok_kwargs = self._process_inputs(
            rid,
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
        )

        self.llm_engine.add_request(
            rid,
            processed,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tok_kwargs,
            priority=priority,
            prompt_text=text,
        )
        return rid
1716

1717
    def _run_engine(
        self, *, use_tqdm: bool | Callable[..., tqdm] = True
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until it has no unfinished requests.

        Args:
            use_tqdm: If truthy, display a progress bar. A callable is used
                as the progress-bar factory in place of `tqdm`.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        pbar = None
        num_requests = 0
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                for output in self.llm_engine.step():
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue

                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs
                        )
                        elapsed = pbar.format_dict["elapsed"]
                        # elapsed can be 0.0 (coarse clock / custom factory);
                        # skip the speed update rather than divide by zero.
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s"
                            )
                        pbar.update(n)
                    else:
                        pbar.update(1)

                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Close the bar even if stepping raises (e.g. KeyboardInterrupt),
            # so the terminal is not left with a dangling progress bar.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))