llm.py 73.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
6
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
    CompilationConfig,
22
    PoolerConfig,
23
24
25
    StructuredOutputsConfig,
    is_init_field,
)
26
from vllm.config.compilation import CompilationMode
27
from vllm.config.model import (
28
29
    ConvertOption,
    HfOverrides,
30
    ModelDType,
31
32
    RunnerOption,
)
33
from vllm.config.renderer import TokenizerMode
34
from vllm.engine.arg_utils import EngineArgs
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
    resolve_chat_template_content_format,
)
from vllm.entrypoints.score_utils import (
    ScoreContentPartParam,
    ScoreMultiModalParam,
    _cosine_similarity,
    _validate_score_input_lens,
    compress_token_type_ids,
    get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
from vllm.inputs import (
    DataPrompt,
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
59
from vllm.inputs.parse import get_prompt_components
60
from vllm.logger import init_logger
61
from vllm.lora.request import LoRARequest
62
from vllm.model_executor.layers.quantization import QuantizationMethods
63
64
65
66
67
68
69
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
70
from vllm.platforms import current_platform
71
from vllm.pooling_params import PoolingParams
72
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
73
from vllm.tasks import PoolingTask
74
from vllm.tokenizers import MistralTokenizer, TokenizerLike
75
from vllm.tokenizers.hf import get_cached_tokenizer
yhu422's avatar
yhu422 committed
76
from vllm.usage.usage_lib import UsageContext
77
from vllm.utils.collection_utils import as_iter, is_list_of
78
from vllm.utils.counter import Counter
79
from vllm.v1.engine import EngineCoreRequest
80
from vllm.v1.engine.llm_engine import LLMEngine
81
from vllm.v1.sample.logits_processor import LogitsProcessor
82

83
84
85
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

logger = init_logger(__name__)

# Result type of the callables executed on workers by `LLM.collective_rpc`
# and `LLM.apply_model`; it carries `Any` as its (PEP 696) default.
_R = TypeVar("_R", default=Any)


class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
92
93
94
95
96
97
98
99
100
101
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
102
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
103
104
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
105
106
107
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
108
109
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
110
111
112
113
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
119
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
120
121
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
122
        quantization: The method used to quantize the model weights. Currently,
123
            we support "awq", "gptq", and "fp8" (experimental).
124
125
126
127
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
128
129
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
130
131
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
132
133
134
135
136
137
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
138
139
140
141
142
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vLLM can automatically infer the KV cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the KV cache memory size. kv_cache_memory_bytes
            allows more fine-grained control of how much memory gets used
            than gpu_memory_utilization does. Note that
            kv_cache_memory_bytes (when not None) overrides
            gpu_memory_utilization.
146
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
152
153
154
155
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
156
157
158
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
159
160
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
161
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
164
165
166
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
167
168
169
170
171
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
172
173
        pooler_config: Initialize non-default pooling config for the pooling
            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
174
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig`. If it is an integer, it is used as the mode of
            compilation optimization. If it is a dictionary, it can specify the
            full compilation configuration; keys that are not init fields of
            `CompilationConfig` are ignored.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
178

179
180
    Note:
        This class is intended to be used for offline inference. For online
181
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
182
    """
183
184
185
186

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int | None = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """Build the `LLMEngine` that backs this offline-inference handle.

        The explicit keyword arguments are described in the class docstring;
        anything else in ``kwargs`` is forwarded unchanged to `EngineArgs`.

        Raises:
            ValueError: If ``kv_transfer_config`` is passed as a dict that does
                not validate as a `KVTransferConfig`, or if
                ``data_parallel_size > 1`` is requested without the
                ``"external_launcher"`` executor backend on a non-TPU platform.
        """
        # Stats logging is off by default for offline use unless explicitly set.
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # A worker class object (rather than a qualified string name) is
            # serialized with cloudpickle to avoid pickling issues.
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if "kv_transfer_config" in kwargs and isinstance(
            kwargs["kv_transfer_config"], dict
        ):
            # Imported lazily: only needed when a raw dict must be converted.
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        # Normalize `compilation_config` into a CompilationConfig instance:
        # an int selects the compilation mode, a dict is filtered down to the
        # dataclass's init fields, and an instance is used as-is.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        elif isinstance(compilation_config, dict):
            compilation_config_instance = CompilationConfig(
                **{
                    k: v
                    for k, v in compilation_config.items()
                    if is_init_field(CompilationConfig, k)
                }
            )
        else:
            compilation_config_instance = compilation_config

        # Same normalization for the structured-outputs configuration.
        if structured_outputs_config is None:
            structured_outputs_instance = StructuredOutputsConfig()
        elif isinstance(structured_outputs_config, dict):
            structured_outputs_instance = StructuredOutputsConfig(
                **{
                    k: v
                    for k, v in structured_outputs_config.items()
                    if is_init_field(StructuredOutputsConfig, k)
                }
            )
        else:
            structured_outputs_instance = structured_outputs_config

        # Reject single-process data parallelism, which is unsupported and may
        # hang; the external launcher and TPU platforms are exempt.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by `get_default_sampling_params`.
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        # Convenience handles onto engine-owned components.
        self.renderer_config = self.llm_engine.renderer_config
        self.model_config = self.llm_engine.model_config
        self.input_processor = self.llm_engine.input_processor
        self.io_processor = self.llm_engine.io_processor

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer held by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()

    @deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
    def set_tokenizer(self, tokenizer: TokenizerLike) -> None:
        """Install ``tokenizer`` on the underlying engine.

        A tokenizer whose class name starts with ``"Cached"`` is installed
        as-is; any other tokenizer is first wrapped with
        ``get_cached_tokenizer``.
        """
        # The cached tokenizer class is created dynamically, so comparing the
        # class name is the only available check. A user-defined tokenizer
        # whose class name happens to start with "Cached" will be misjudged.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        self.llm_engine.tokenizer = (
            tokenizer if looks_cached else get_cached_tokenizer(tokenizer)
        )

    def reset_mm_cache(self) -> None:
        """Clear the multimodal caches: input processor first, then engine."""
        processor = self.input_processor
        processor.clear_mm_cache()
        engine = self.llm_engine
        engine.reset_mm_cache()

    def get_default_sampling_params(self) -> SamplingParams:
        """Return the sampling parameters used when the caller supplies none.

        The overrides from ``model_config.get_diff_sampling_param()`` are
        fetched on first use and cached in ``self.default_sampling_params``.
        An empty override mapping yields a plain ``SamplingParams()``.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = self.model_config.get_diff_sampling_param()

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters
                (see `get_default_sampling_params`).
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. For
                multimodal models with modality-specific default LoRAs
                configured, a prompt without an explicit request may be
                assigned the default LoRA of its modality.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the prompt
                at the same index.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the loaded model is not a generative model
                (its runner type is not ``"generate"``).
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def _get_modality_specific_lora_reqs(
        self,
        prompts: PromptType | Sequence[PromptType],
        lora_request: list[LoRARequest] | LoRARequest | None,
    ):
        """Resolve per-prompt LoRA requests, applying modality default LoRAs.

        When the engine has a LoRA config with ``default_mm_loras`` and the
        model is multimodal, each prompt is paired with its LoRA request and
        resolved via `_resolve_single_prompt_mm_lora`. Otherwise
        ``lora_request`` is returned unchanged.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: ``None`` or a single request (applied to every
                prompt), or a sequence with exactly one entry per prompt.

        Returns:
            Either the original ``lora_request`` or a list with one resolved
            (possibly ``None``) LoRA request per prompt.

        Raises:
            ValueError: If ``lora_request`` is a sequence whose length differs
                from the number of prompts. Previously `zip` silently
                truncated the pairing in that case.
        """
        lora_config = self.llm_engine.vllm_config.lora_config

        # Without a lora config / default_mm_loras, or for a model that isn't
        # multimodal, there is nothing to resolve.
        if (
            lora_config is None
            or not self.model_config.is_multimodal_model
            or lora_config.default_mm_loras is None
        ):
            return lora_request

        # A lone prompt (including a plain string) is treated as a batch of one.
        if not isinstance(prompts, Sequence) or isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(prompts)

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            )
            for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

    def _resolve_single_prompt_mm_lora(
        self,
        prompt: PromptType,
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
    ):
        """Choose the LoRA request to apply to a single prompt.

        A new `LoRARequest` is built only when the prompt is a dict whose
        ``multi_modal_data`` uses exactly one modality that has a registered
        default LoRA, and no explicit ``lora_request`` was given. In every
        other case the caller's ``lora_request`` (possibly ``None``) is
        returned unchanged.

        The generated request is named after the modality. Its integer ID is
        the 1-based index of the modality among the alphabetically sorted
        keys of ``default_mm_loras``. Its path is the registered LoRA path.
        """
        # Bail out early when there is no default to apply or no mm data.
        if not default_mm_loras or not isinstance(prompt, dict):
            return lora_request
        mm_data = prompt.get("multi_modal_data") or {}
        if not mm_data:
            return lora_request

        matched_modalities = set(
            mm_data.keys()  # type: ignore
        ).intersection(default_mm_loras.keys())
        if not matched_modalities:
            return lora_request

        if len(matched_modalities) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped",
                matched_modalities,
            )
            return lora_request

        (modality,) = matched_modalities
        modality_lora_id = sorted(default_mm_loras).index(modality) + 1

        if not lora_request:
            return LoRARequest(
                modality,
                modality_lora_id,
                default_mm_loras[modality],
            )

        # An explicitly provided request always wins; only warn when its ID
        # disagrees with the one the modality default would have used.
        if lora_request.lora_int_id != modality_lora_id:
            logger.warning(
                "A modality with a registered lora and a lora_request "
                "with a different ID were provided; falling back to the "
                "lora_request as we only apply one LoRARequest per prompt"
            )
        return lora_request

    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Run the same call on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and executed on each worker.

                A callable receives the worker object as an extra leading
                `self` argument, followed by `args` and `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait. A
                [`TimeoutError`][] is raised when it elapses; `None` waits
                without limit.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, collected into a list.

        Note:
            Prefer this API for control messages only; move bulk data over a
            dedicated data-plane channel instead.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model held by each worker and collect what each
        call returns.

        !!! warning
            Returning large arrays or tensors is costly to transfer. If you
            must return them, move them to CPU first so they do not occupy
            extra VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)

    def _get_beam_search_lora_requests(
        self,
        lora_request: list[LoRARequest] | LoRARequest | None,
        prompts: list[TokensPrompt | TextPrompt],
    ) -> list[LoRARequest | None]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: ``None`` or a single `LoRARequest` (applied to every
                prompt), or a sequence with exactly one entry per prompt.
            prompts: The beam-search prompts the requests are paired with.

        Returns:
            A new list holding one (possibly ``None``) LoRA request per prompt.

        Raises:
            ValueError: If a sequence is given whose length differs from
                ``len(prompts)``.
            TypeError: If ``lora_request`` is of any other type.
        """
        if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts"
            )

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        # A correctly sized per-prompt sequence is accepted by the signature,
        # so pass it through (as a fresh list, to avoid aliasing the caller's).
        if isinstance(lora_request, Sequence):
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

    def beam_search(
        self,
598
        prompts: list[TokensPrompt | TextPrompt],
599
        params: BeamSearchParams,
600
        lora_request: list[LoRARequest] | LoRARequest | None = None,
601
        use_tqdm: bool = False,
602
        concurrency_limit: int | None = None,
603
    ) -> list[BeamSearchOutput]:
604
605
606
607
608
609
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
610
            params: The beam search parameters.
611
            lora_request: LoRA request to use for generation, if any.
612
            use_tqdm: Whether to use tqdm to display the progress bar.
613
614
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
615
        """
616
617
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
618
619
620
621
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
622
623
        length_penalty = params.length_penalty

624
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
625

626
627
628
629
630
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
631

632
633
634
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
635
636
                "Disabling progress bar."
            )
637
638
639
640
641
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

642
643
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
644
645
646
647
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
648
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
649
            return TokensPrompt(**token_prompt_kwargs)
650

651
652
653
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
654
655
656
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width, max_tokens=1, temperature=temperature
        )
657
        instances: list[BeamSearchInstance] = []
658

659
        for lora_req, prompt in zip(lora_requests, prompts):
660
661
662
663
664
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
665
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
666

667
668
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
669
670
671
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
672

673
            instances.append(
674
675
676
677
678
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
679
680
                ),
            )
681

682
        for prompt_start in range(0, len(prompts), concurrency_limit):
683
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
684
685
686

            token_iter = range(max_tokens)
            if use_tqdm:
687
688
689
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
690
691
692
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
693
694
                    "reflect instance-level progress."
                )
695
696
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
697
698
                    sum((instance.beams for instance in instances_batch), [])
                )
699
700
                pos = [0] + list(
                    itertools.accumulate(
701
702
703
                        len(instance.beams) for instance in instances_batch
                    )
                )
704
                instance_start_and_end: list[tuple[int, int]] = list(
705
706
                    zip(pos[:-1], pos[1:])
                )
707
708
709
710
711
712

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
713
714
715
716
717
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
718
719
720

                # only runs for one step
                # we don't need to use tqdm here
721
722
723
724
725
726
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
727

728
729
730
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
731
732
733
734
735
736
737
738
739
740
741
742
743
744
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
745
                                    logprobs=current_beam.logprobs + [logprobs],
746
                                    lora_request=current_beam.lora_request,
747
748
749
750
751
752
753
754
755
756
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
757
758
759
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
760
761
762
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
763
                    instance.beams = sorted_beams[:beam_width]
764
765
766
767

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
768
769
770
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
771
772
773
774
775
776
777
778
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

779
    def preprocess_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[TokensPrompt]:
        """
        Render one or more chat conversations into token prompts.

        Each conversation is parsed (including any multi-modal content),
        rendered with the chat template and tokenized. The resulting prompts
        can be passed directly to the other LLM methods.

        Refer to `chat` for a complete description of the arguments.

        Returns:
            One `TokensPrompt` per conversation, in input order, holding the
            prompt token IDs plus (when present) the parsed multi-modal data,
            the multi-modal UUIDs and the given `mm_processor_kwargs`.
        """
        # A list whose elements are lists is a batch of conversations;
        # anything else is treated as one single conversation.
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list)
            else [cast(list[ChatCompletionMessageParam], messages)]
        )

        tokenizer = self.get_tokenizer()
        renderer_config = self.renderer_config
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            renderer_config=renderer_config,
        )

        # Explicit arguments first; user-supplied chat_template_kwargs are
        # applied afterwards so they win on key collisions.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
        }
        if chat_template_kwargs:
            template_kwargs.update(chat_template_kwargs)

        uses_mistral = isinstance(tokenizer, MistralTokenizer)

        def _to_tokens_prompt(
            conversation_msgs: list[ChatCompletionMessageParam],
        ) -> TokensPrompt:
            # Convert a single conversation into a TokensPrompt.
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            parsed_conversation, mm_data, mm_uuids = parse_chat_messages(
                conversation_msgs,
                renderer_config,
                content_format=content_format,
            )

            if uses_mistral:
                # The Mistral path templates the raw messages and yields IDs.
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=parsed_conversation,
                    renderer_config=renderer_config,
                    **template_kwargs,
                )
                # The chat template already emits special tokens, so the
                # tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered, add_special_tokens=False)

            tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                tokens_prompt["multi_modal_data"] = mm_data
            if mm_uuids is not None:
                tokens_prompt["multi_modal_uuids"] = mm_uuids
            if mm_processor_kwargs is not None:
                tokens_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            return tokens_prompt

        return [_to_tokens_prompt(conv) for conv in conversations]

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered with the chat template and tokenized by
        `preprocess_chat`. The resulting token prompts are then passed to the
        [generate][vllm.LLM.generate] method to generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: Passed to the chat template for each
                conversation; if True, the template is asked to add a
                generation prompt after the last message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are forwarded to the chat
                template and also used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. These override the explicit template arguments
                above on key collisions.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """

        prompts = self.preprocess_chat(
            messages=messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

959
960
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
                A dict containing a `"data"` key is treated as an IOProcessor
                plugin request. It is validated with `parse_request`,
                converted with `pre_process`, and the model outputs are run
                through `post_process`.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. For IOProcessor requests
                they are passed through `validate_or_generate_params`.
            truncate_prompt_tokens: Kept for backwards compatibility. If not
                None, it is written onto every pooling params object,
                including ones supplied by the caller.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run. It is required and must be
                one of the model's supported tasks.
            tokenization_kwargs: overrides tokenization_kwargs set in
                pooling_params

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For IOProcessor requests, a single-element list wrapping the
            plugin's post-processed output.

        Raises:
            ValueError: If `pooling_task` is None, the model is not a pooling
                model, `pooling_task` is not supported by the model, or an
                IOProcessor request is given but no plugin is installed.
        """
        if pooling_task is None:
            # The guidance message is only built on this error path.
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        # Validate the task before any (potentially expensive) IOProcessor
        # request parsing / pre-processing, so unsupported tasks fail fast.
        if pooling_task not in self.supported_tasks:
            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        if io_processor_prompt:
            assert self.io_processor is not None
            if is_list_of(pooling_params, PoolingParams):
                # Let the plugin validate each user-supplied params object.
                validated_pooling_params: list[PoolingParams] = []
                for param in as_iter(pooling_params):
                    validated_pooling_params.append(
                        self.io_processor.validate_or_generate_params(param)
                    )
                pooling_params = validated_pooling_params
            else:
                # A single params object (or None) is validated / generated
                # by the plugin.
                assert not isinstance(pooling_params, Sequence)
                pooling_params = self.io_processor.validate_or_generate_params(
                    pooling_params
                )
        else:
            if pooling_params is None:
                # Use default pooling params.
                pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if io_processor_prompt:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(
                model_output=model_outputs
            )

            # Wrap the plugin output in a single synthetic request output.
            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
                    prompt_token_ids=[],
                    finished=True,
                )
            ]
        else:
            return model_outputs
1109

1110
1111
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per prompt.

        This is a thin wrapper around `encode` with `pooling_task="embed"`.
        Prompts are batched automatically subject to the memory constraint,
        so passing all prompts in a single call gives the best performance.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`, which writes it
                onto every pooling params object when not None.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Forwarded to `encode`; overrides
                tokenization_kwargs set in pooling_params.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        embedding_supported = "embed" in self.supported_tasks
        if not embedding_supported:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        pooled = self.encode(
            prompts,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        # Re-wrap each generic pooling output as an embedding output.
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Forwarded to `encode`; overrides
                tokenization_kwargs set in pooling_params.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification outputs (class logits) in the same order as the
            input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1209
1210
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Compute rewards for each prompt.

        Rewards are produced by running `encode` with the "token_classify"
        pooling task. No task-support check is done here; `encode` performs
        it.

        Args:
            prompts: The prompts to the LLM (positional-only). You may pass a
                sequence of prompts for batch inference. See
                [PromptType][vllm.inputs.PromptType] for more details about
                the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`, which writes it
                onto every pooling params object when not None.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Forwarded to `encode`; overrides
                tokenization_kwargs set in pooling_params.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        encode_kwargs: dict[str, Any] = {
            "pooling_task": "token_classify",
            "pooling_params": pooling_params,
            "truncate_prompt_tokens": truncate_prompt_tokens,
            "tokenization_kwargs": tokenization_kwargs,
            "lora_request": lora_request,
            "use_tqdm": use_tqdm,
        }
        return self.encode(prompts, **encode_kwargs)

1249
1250
    def _embedding_score(
        self,
        tokenizer: TokenizerLike,
        text_1: list[str | TextPrompt | TokensPrompt],
        text_2: list[str | TextPrompt | TokensPrompt],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by cosine similarity of their embeddings.

        Both sides are embedded in one `encode` call with the "embed" task,
        and the outputs are split back at `len(text_1)`. A single left-hand
        embedding is repeated to pair with every right-hand one (1 -> N).
        """
        split_at = len(text_1)
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )

        left: list[PoolingRequestOutput] = embeddings[:split_at]
        right: list[PoolingRequestOutput] = embeddings[split_at:]

        # 1 -> N: the lone left-hand embedding is reused for every pair.
        if len(left) == 1:
            left = left * len(right)

        similarities = _cosine_similarity(
            tokenizer=tokenizer, embed_1=left, embed_2=right
        )
        validated = self.engine_class.validate_outputs(
            similarities, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(entry) for entry in validated]

    def _cross_encoding_score(
        self,
        tokenizer: TokenizerLike,
        data_1: list[str] | list[ScoreContentPartParam],
        data_2: list[str] | list[ScoreContentPartParam],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs with a cross-encoder by running each pair jointly.

        A single `data_1` entry is repeated to match `data_2` (1 -> N). One
        engine prompt is built per pair with `get_score_prompt`. When that
        prompt carries `token_type_ids`, they are compressed into a cloned
        PoolingParams' `extra_kwargs`; otherwise the shared params are used.

        Raises:
            ValueError: If the tokenizer is a MistralTokenizer.
        """
        renderer_config = self.renderer_config
        model_config = self.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        pooling_params.verify("score", model_config)

        # Work on a private copy so the caller's dict is never modified;
        # presumably _validate_truncation_size writes truncation settings
        # into it (TODO confirm against its implementation).
        tok_kwargs = (tokenization_kwargs or {}).copy()
        _validate_truncation_size(
            model_config.max_model_len, truncate_prompt_tokens, tok_kwargs
        )

        prompts = list[PromptType]()
        params_per_prompt = list[PoolingParams]()

        for query, doc in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                renderer_config=renderer_config,
                data_1=query,
                data_2=doc,
                tokenizer=tokenizer,
                tokenization_kwargs=tok_kwargs,
            )

            token_type_ids = engine_prompt.pop("token_type_ids", None)
            if not token_type_ids:
                request_params = pooling_params
            else:
                # token_type_ids travel to the model via a per-request clone.
                request_params = pooling_params.clone()
                request_params.extra_kwargs = {
                    "compressed_token_type_ids": compress_token_type_ids(
                        token_type_ids
                    )
                }
            params_per_prompt.append(request_params)
            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=params_per_prompt,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        raw_outputs = self._run_engine(use_tqdm=use_tqdm)
        validated = self.engine_class.validate_outputs(
            raw_outputs, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(entry) for entry in validated]

1351
1352
    def score(
        self,
        data_1: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
        `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Optional limit on the number of prompt
                tokens. Forwarded unchanged to the cross-encoder or
                embedding scoring path.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a pooling model, supports
                neither `embed` nor `classify`, is a cross-encoder whose
                `num_labels` is not 1, or if a text-only model receives
                multi-modal input or a prompt that cannot be turned into text.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        if (
            model_config.is_cross_encoder
            and getattr(model_config.hf_config, "num_labels", 0) != 1
        ):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(
                data: SingletonPrompt
                | Sequence[SingletonPrompt]
                | ScoreMultiModalParam,
            ):
                # A dict carrying "content" is a ScoreMultiModalParam, which a
                # text-only model cannot consume.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError(
                        "ScoreMultiModalParam is not supported "
                        f"for {model_config.architecture}"
                    )

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt) -> str:
                """Reduce a single text-only prompt to a plain string."""
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError(
                            "Multi-modal prompt is not supported for scoring"
                        )
                    elif "prompt_token_ids" in prompt:
                        # The tokenizer for models such as
                        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't
                        # support passing lists of tokens to the `text` and
                        # `text_pair` kwargs, so token prompts are decoded
                        # back to text.
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"]
                        )
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check rather than `assert`: assertions are stripped
                # under `python -O`, which would let non-text prompts through.
                if not isinstance(prompt, str):
                    raise ValueError(
                        "Scoring prompts must be text or token IDs, "
                        f"got {type(prompt).__name__}"
                    )
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Multi-modal path: unwrap a ScoreMultiModalParam to its "content"
        # list, and wrap a bare string into a one-element list.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )
1500

1501
1502
1503
1504
1505
1506
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to end profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1507
1508
1509
1510
1511
1512
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache.

        Args:
            reset_running_requests: Forwarded positionally to the engine.
            reset_connector: Forwarded positionally to the engine.

        Returns:
            The boolean reported by `llm_engine.reset_prefix_cache`.
        """
        engine = self.llm_engine
        outcome = engine.reset_prefix_cache(reset_running_requests, reset_connector)
        return outcome
1513

1514
1515
1516
1517
1518
1519
    def sleep(self, level: int = 1):
        """
        Put the engine into sleep mode; it must not serve requests meanwhile.
        The caller is responsible for ensuring nothing is being processed
        between this call and the matching `wake_up`.

        Args:
            level: How deeply to sleep.
                Level 1 offloads the model weights to CPU memory and discards
                the kv cache, so the kv cache contents are lost. Use it when
                the same model will be served again after waking up. Make sure
                there is enough CPU memory to hold the weights.
                Level 2 discards both the model weights and the kv cache, so
                neither is preserved. Use it when a different or updated model
                will be loaded after waking up and the old weights are not
                needed. It reduces CPU memory pressure.
        """
        # Clear the prefix cache before handing control to the engine's sleep.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1536
    def wake_up(self, tags: list[str] | None = None):
1537
        """
1538
1539
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1540

1541
        Args:
1542
1543
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1544
                `("weights", "kv_cache")`. If None, all memory is reallocated.
1545
                wake_up should be called with all tags (or None) before the
1546
1547
1548
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1549

1550
1551
1552
1553
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus, as returned by the engine's
            `get_metrics`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1562
1563
    def _validate_and_add_requests(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        params: SamplingParams
        | Sequence[SamplingParams]
        | PoolingParams
        | Sequence[PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Check a batch of prompts against its per-request settings and
        enqueue every prompt on the engine.

        Args:
            prompts: One prompt (a str or dict) or a sequence of prompts.
            params: A single params object shared by all prompts, or a
                sequence with exactly one entry per prompt. Any
                `SamplingParams` are switched to `FINAL_ONLY` output in place.
            use_tqdm: Truthy to show a progress bar while adding requests; a
                callable is used as the bar factory instead of `tqdm`.
            lora_request: One LoRA request shared by all prompts, a sequence
                with one entry per prompt, or None.
            priority: Optional per-prompt priorities; 0 is used when omitted.
            tokenization_kwargs: Passed through to each `_add_request` call.

        Raises:
            ValueError: If a per-prompt sequence length does not match the
                number of prompts.
            Exception: Any error raised while validating or adding a prompt is
                re-raised after aborting the requests already added.
        """
        if isinstance(prompts, (str, dict)):
            # A lone prompt is handled as a batch of one.
            prompts = [prompts]  # type: ignore[list-item]
        num_requests = len(prompts)

        params_per_prompt = isinstance(params, Sequence)
        lora_per_prompt = isinstance(lora_request, Sequence)

        if params_per_prompt and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params must be the same.")
        if lora_per_prompt and len(lora_request) != num_requests:
            raise ValueError(
                "The lengths of prompts and lora_request must be the same."
            )
        if priority is not None and len(priority) != num_requests:
            raise ValueError(
                "The lengths of prompts "
                f"({num_requests}) and priority ({len(priority)}) "
                "must be the same."
            )

        # Only the final output of each request matters here, so every
        # SamplingParams is forced to FINAL_ONLY (PoolingParams are untouched).
        all_params = params if params_per_prompt else (params,)
        for candidate in all_params:
            if isinstance(candidate, SamplingParams):
                candidate.output_kind = RequestOutputKind.FINAL_ONLY

        prompt_iter = prompts
        if use_tqdm:
            make_bar = use_tqdm if callable(use_tqdm) else tqdm
            prompt_iter = make_bar(prompt_iter, desc="Adding requests")

        added_request_ids: list[str] = []
        try:
            for idx, prompt in enumerate(prompt_iter):
                if isinstance(prompt, dict):
                    self._validate_mm_data_and_uuids(
                        prompt.get("multi_modal_data"),
                        prompt.get("multi_modal_uuids"),
                    )
                added_request_ids.append(
                    self._add_request(
                        prompt,
                        params[idx] if params_per_prompt else params,
                        lora_request=(
                            lora_request[idx] if lora_per_prompt else lora_request
                        ),
                        priority=priority[idx] if priority else 0,
                        tokenization_kwargs=tokenization_kwargs,
                    )
                )
        except Exception:
            # Abort whatever was already enqueued so a failed batch does not
            # leave a partial set of requests in the engine.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids)
            raise
1626

1627
    def _validate_mm_data_and_uuids(
1628
        self,
1629
1630
        multi_modal_data: Any | None,  # MultiModalDataDict
        multi_modal_uuids: Any | None,  # MultiModalUUIDDict
1631
1632
1633
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1634
        then its corresponding UUID must be set.
1635
1636
1637
1638
1639
1640
1641
1642
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
1643
1644
1645
1646
1647
1648
1649
1650
                        if (
                            multi_modal_uuids is None
                            or modality not in multi_modal_uuids
                            or multi_modal_uuids[  # noqa: E501
                                modality
                            ]
                            is None
                        ):
1651
1652
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
1653
1654
                                f"but UUID is not provided"
                            )
1655
                        else:
1656
1657
1658
1659
                            if (
                                len(multi_modal_uuids[modality]) <= i
                                or multi_modal_uuids[modality][i] is None
                            ):
1660
1661
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
1662
1663
                                    f"but UUID is not provided"
                                )
1664
            else:
1665
1666
1667
1668
1669
1670
1671
1672
1673
                if data is None and (
                    multi_modal_uuids is None
                    or modality not in multi_modal_uuids
                    or multi_modal_uuids[modality] is None
                ):
                    raise ValueError(
                        f"Multi-modal data for {modality} is None"
                        f" but UUID is not provided"
                    )
1674

1675
1676
1677
1678
    def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        priority: int,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to process inputs for LLMEngine.

        Returns:
            The engine request built by `input_processor.process_inputs`, and
            the tokenization kwargs copy that was passed to it (which
            `_validate_truncation_size` may have updated — presumably with
            truncation settings).
        """
        # Work on a fresh (shallow) copy so the caller's dict object is not
        # modified by the calls below.
        kwargs = (tokenization_kwargs or {}).copy()

        _validate_truncation_size(
            self.model_config.max_model_len,
            params.truncate_prompt_tokens,
            kwargs,
        )

        request = self.input_processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=kwargs,
            priority=priority,
        )
        return request, kwargs

1705
    def _add_request(
        self,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> str:
        """Process one prompt and hand the resulting request to the engine.

        Returns:
            The request ID, the string form of the next value drawn from
            `self.request_counter`.
        """
        prompt_text, _, _ = get_prompt_components(prompt)
        request_id = str(next(self.request_counter))

        engine_request, used_tok_kwargs = self._process_inputs(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
        )

        self.llm_engine.add_request(
            request_id,
            engine_request,
            params,
            lora_request=lora_request,
            tokenization_kwargs=used_tok_kwargs,
            priority=priority,
            prompt_text=prompt_text,
        )
        return request_id
1735

1736
    def _run_engine(
        self, *, use_tqdm: bool | Callable[..., tqdm] = True
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If truthy, show a progress bar sized to the number of
                requests unfinished at the start. A callable is used as the
                progress-bar factory; otherwise `tqdm` is used.

        Returns:
            Every finished output, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if not output.finished:
                    continue
                outputs.append(output)
                if not use_tqdm:
                    continue
                if isinstance(output, RequestOutput):
                    # Calculate tokens only for RequestOutput
                    n = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids) * n
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs
                    )
                    # `elapsed` can be 0 (e.g. a disabled bar supplied via a
                    # `use_tqdm` callable, or a request finishing before the
                    # clock advances); skip the estimate rather than raise
                    # ZeroDivisionError.
                    elapsed = pbar.format_dict["elapsed"]
                    if elapsed > 0:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s"
                        )
                    pbar.update(n)
                else:
                    pbar.update(1)
                if pbar.n == num_requests:
                    pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID. This is necessary because some
        # requests may finish earlier than requests submitted before them.
        return sorted(outputs, key=lambda x: int(x.request_id))