llm.py 72.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
6
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
    CompilationConfig,
22
    PoolerConfig,
23
24
25
    StructuredOutputsConfig,
    is_init_field,
)
26
from vllm.config.compilation import CompilationMode
27
from vllm.config.model import (
28
29
    ConvertOption,
    HfOverrides,
30
    ModelDType,
31
    RunnerOption,
32
    TokenizerMode,
33
)
34
from vllm.engine.arg_utils import EngineArgs
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
    resolve_chat_template_content_format,
)
from vllm.entrypoints.score_utils import (
    ScoreContentPartParam,
    ScoreMultiModalParam,
    _cosine_similarity,
    _validate_score_input_lens,
    compress_token_type_ids,
    get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
from vllm.inputs import (
    DataPrompt,
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
59
from vllm.inputs.parse import get_prompt_components
60
from vllm.logger import init_logger
61
from vllm.lora.request import LoRARequest
62
from vllm.model_executor.layers.quantization import QuantizationMethods
63
64
65
66
67
68
69
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
70
from vllm.platforms import current_platform
71
from vllm.pooling_params import PoolingParams
72
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
73
from vllm.tasks import PoolingTask
74
75
76
77
78
from vllm.transformers_utils.tokenizer import (
    AnyTokenizer,
    MistralTokenizer,
    get_cached_tokenizer,
)
yhu422's avatar
yhu422 committed
79
from vllm.usage.usage_lib import UsageContext
80
from vllm.utils.collection_utils import as_iter, is_list_of
81
from vllm.utils.counter import Counter
82
from vllm.v1.engine import EngineCoreRequest
83
from vllm.v1.engine.llm_engine import LLMEngine
84
from vllm.v1.sample.logits_processor import LogitsProcessor
85

86
87
88
if TYPE_CHECKING:
    # Needed only by static type checkers; this import never runs.
    from vllm.v1.metrics.reader import Metric

# One logger for the whole module, shared by every `LLM` instance.
logger = init_logger(__name__)

# Return type of the worker-side callables passed to `collective_rpc` and
# `apply_model`; defaults to `Any` when a checker cannot solve it.
_R = TypeVar("_R", default=Any)


class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
95
96
97
98
99
100
101
102
103
104
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
105
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
106
107
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
108
109
110
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
111
112
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
113
114
115
116
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
119
120
121
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
122
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
123
124
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
125
        quantization: The method used to quantize the model weights. Must be
            one of the supported `QuantizationMethods` values.
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
131
132
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
133
134
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
135
136
137
138
139
140
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
141
142
143
144
145
        kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. By
            default, this is None and vLLM infers the KV cache size from
            `gpu_memory_utilization`. Setting it explicitly gives more
            fine-grained control over how much memory is used than
            `gpu_memory_utilization` does. Note that when
            `kv_cache_memory_bytes` is not None, `gpu_memory_utilization`
            is ignored.
149
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
155
156
157
158
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
159
160
161
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
162
163
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
164
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
167
168
169
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
170
171
172
173
174
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
175
176
        pooler_config: Initialize non-default pooling config for the pooling
            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
177
        compilation_config: An integer, a dictionary, or a `CompilationConfig`
            instance. If it is an integer, it is used as the compilation mode
            (`CompilationMode`). If it is a dictionary, it can specify the full
            compilation configuration; keys that are not `CompilationConfig`
            init fields are ignored.
180
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
181

182
183
    Note:
        This class is intended to be used for offline inference. For online
184
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
185
    """
186
187
188
189

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int | None = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes the convenience forms of several arguments (dicts, ints,
        worker classes) into what `EngineArgs` expects, rejects unsupported
        single-process data-parallel setups, and builds the `LLMEngine`.

        Raises:
            ValueError: If `kv_transfer_config` is given as a dict that fails
                `KVTransferConfig` validation, or if `data_parallel_size > 1`
                is requested without the "external_launcher" distributed
                executor backend on a non-TPU platform.
        """
        # Stats logging is off unless the caller explicitly asks for it.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if isinstance(kwargs.get("kv_transfer_config"), dict):
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        # An int selects only the compilation mode; a dict is filtered down to
        # CompilationConfig's init fields; a config instance is used as-is.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        elif isinstance(compilation_config, dict):
            compilation_config_instance = self._build_config_from_dict(
                CompilationConfig, compilation_config
            )
        else:
            compilation_config_instance = compilation_config

        if structured_outputs_config is None:
            structured_outputs_instance = StructuredOutputsConfig()
        elif isinstance(structured_outputs_config, dict):
            structured_outputs_instance = self._build_config_from_dict(
                StructuredOutputsConfig, structured_outputs_config
            )
        else:
            structured_outputs_instance = structured_outputs_config

        # Reject single-process data parallel usage (it may hang); it is only
        # allowed with the external launcher backend or on TPU.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled in lazily by `get_default_sampling_params`.
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        self.model_config = self.llm_engine.model_config
        self.input_processor = self.llm_engine.input_processor
        self.io_processor = self.llm_engine.io_processor

    @staticmethod
    def _build_config_from_dict(config_cls: type, raw_config: dict[str, Any]) -> Any:
        """Instantiate `config_cls` from `raw_config`.

        Keys that `is_init_field` does not report as init fields of
        `config_cls` are silently dropped.
        """
        return config_cls(
            **{
                k: v
                for k, v in raw_config.items()
                if is_init_field(config_cls, k)
            }
        )

    def get_tokenizer(self) -> AnyTokenizer:
        return self.llm_engine.get_tokenizer()
355

356
    @deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
357
    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
358
359
360
361
        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        if tokenizer.__class__.__name__.startswith("Cached"):
362
            self.llm_engine.tokenizer = tokenizer
363
        else:
364
            self.llm_engine.tokenizer = get_cached_tokenizer(tokenizer)
365

366
    def reset_mm_cache(self) -> None:
367
        self.input_processor.clear_mm_cache()
368
369
        self.llm_engine.reset_mm_cache()

370
    def get_default_sampling_params(self) -> SamplingParams:
        """Build a new `SamplingParams` from the model's sampling defaults.

        The overrides returned by `model_config.get_diff_sampling_param()` are
        stored in `self.default_sampling_params` on first use. When they are
        empty, a plain `SamplingParams()` is returned instead.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = self.model_config.get_diff_sampling_param()

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

    def generate(
        self,
379
380
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
381
        *,
382
383
384
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
385
    ) -> list[RequestOutput]:
Woosuk Kwon's avatar
Woosuk Kwon committed
386
387
        """Generates the completions for the input prompts.

388
        This class automatically batches the given prompts, considering
Woosuk Kwon's avatar
Woosuk Kwon committed
389
390
391
392
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
393
            prompts: The prompts to the LLM. You may pass a sequence of prompts
394
                for batch inference. See [PromptType][vllm.inputs.PromptType]
395
                for more details about the format of each prompt.
Woosuk Kwon's avatar
Woosuk Kwon committed
396
            sampling_params: The sampling parameters for text generation. If
nunjunj's avatar
nunjunj committed
397
398
399
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
400
                prompts and it is paired one by one with the prompt.
401
402
403
404
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
405
            lora_request: LoRA request to use for generation, if any.
406
407
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
408
409
410
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the prompt
                at the same index.
Woosuk Kwon's avatar
Woosuk Kwon committed
411
412

        Returns:
413
            A list of `RequestOutput` objects containing the
414
            generated completions in the same order as the input prompts.
415

416
417
418
419
        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `inputs` parameter.
420
        """
421
        model_config = self.model_config
422
423
        runner_type = model_config.runner_type
        if runner_type != "generate":
424
425
426
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
427
428
                "generative model."
            )
429

430
431
        if sampling_params is None:
            # Use default sampling params.
432
            sampling_params = self.get_default_sampling_params()
433

434
        # Add any modality specific loras to the corresponding prompts
435
        lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)
436

437
        self._validate_and_add_requests(
438
            prompts=prompts,
439
            params=sampling_params,
440
            use_tqdm=use_tqdm,
441
            lora_request=lora_request,
442
443
            priority=priority,
        )
444

445
        outputs = self._run_engine(use_tqdm=use_tqdm)
Joe Runde's avatar
Joe Runde committed
446
        return self.engine_class.validate_outputs(outputs, RequestOutput)
447

448
    def _get_modality_specific_lora_reqs(
449
        self,
450
451
        prompts: PromptType | Sequence[PromptType],
        lora_request: list[LoRARequest] | LoRARequest | None,
452
    ):
453
454
455
456
457
458
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
459
460
        if (
            lora_config is None
461
            or not self.model_config.is_multimodal_model
462
463
            or (lora_config and lora_config.default_mm_loras is None)
        ):
464
465
            return lora_request

466
        if not isinstance(prompts, Sequence) or isinstance(prompts, str):
467
            prompts = [prompts]
468

469
470
471
472
473
        optional_loras = (
            [lora_request] * len(prompts)
            if not isinstance(lora_request, Sequence)
            else lora_request
        )
474
475
476

        return [
            self._resolve_single_prompt_mm_lora(
477
                prompt,
478
479
                opt_lora_req,
                lora_config.default_mm_loras,
480
481
            )
            for prompt, opt_lora_req in zip(prompts, optional_loras)
482
483
        ]

484
485
486
    def _resolve_single_prompt_mm_lora(
        self,
        prompt: PromptType,
487
488
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
489
490
491
492
    ):
        if (
            not default_mm_loras
            or not isinstance(prompt, dict)
493
            or not (mm_data := prompt.get("multi_modal_data") or {})
494
        ):
495
496
            return lora_request

497
498
499
        intersection = set(
            mm_data.keys()  # type: ignore
        ).intersection(default_mm_loras.keys())
500
501
502
503
504
505
506
507
508
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
509
510
511
                " will be skipped",
                intersection,
            )
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # If we have a collision, warn if there is a collision,
        # but always send the explicitly provided request.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
527
528
                    "lora_request as we only apply one LoRARequest per prompt"
                )
529
530
531
532
533
534
535
536
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

537
538
    def collective_rpc(
        self,
539
540
        method: str | Callable[..., _R],
        timeout: float | None = None,
541
        args: tuple = (),
542
        kwargs: dict[str, Any] | None = None,
543
    ) -> list[_R]:
544
545
546
547
548
549
550
551
552
553
554
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
555
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
556
557
558
559
560
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.
561

562
563
564
        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
565
        """
566
567

        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
568
569

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
570
        """
571
572
        Run a function directly on the model inside each worker,
        returning the result for each of them.
573
574
575
576
577
578

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
579
        """
580
        return self.llm_engine.apply_model(func)
581

582
583
    def _get_beam_search_lora_requests(
        self,
584
585
586
        lora_request: list[LoRARequest] | LoRARequest | None,
        prompts: list[TokensPrompt | TextPrompt],
    ) -> list[LoRARequest | None]:
587
        """Get the optional lora request corresponding to each prompt."""
588
        if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts):
589
            raise ValueError(
590
591
                "Lora request list should be the same length as the prompts"
            )
592
593
594
595
596
597

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

598
599
    def beam_search(
        self,
600
        prompts: list[TokensPrompt | TextPrompt],
601
        params: BeamSearchParams,
602
        lora_request: list[LoRARequest] | LoRARequest | None = None,
603
        use_tqdm: bool = False,
604
        concurrency_limit: int | None = None,
605
    ) -> list[BeamSearchOutput]:
606
607
608
609
610
611
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
612
            params: The beam search parameters.
613
            lora_request: LoRA request to use for generation, if any.
614
            use_tqdm: Whether to use tqdm to display the progress bar.
615
616
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
617
        """
618
619
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
620
621
622
623
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
624
625
        length_penalty = params.length_penalty

626
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
627

628
629
630
631
632
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
633

634
635
636
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
637
638
                "Disabling progress bar."
            )
639
640
641
642
643
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

644
645
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
646
647
648
649
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
650
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
651
            return TokensPrompt(**token_prompt_kwargs)
652

653
654
655
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
656
657
658
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width, max_tokens=1, temperature=temperature
        )
659
        instances: list[BeamSearchInstance] = []
660

661
        for lora_req, prompt in zip(lora_requests, prompts):
662
663
664
665
666
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
667
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
668

669
670
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
671
672
673
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
674

675
            instances.append(
676
677
678
679
680
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
681
682
                ),
            )
683

684
        for prompt_start in range(0, len(prompts), concurrency_limit):
685
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
686
687
688

            token_iter = range(max_tokens)
            if use_tqdm:
689
690
691
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
692
693
694
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
695
696
                    "reflect instance-level progress."
                )
697
698
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
699
700
                    sum((instance.beams for instance in instances_batch), [])
                )
701
702
                pos = [0] + list(
                    itertools.accumulate(
703
704
705
                        len(instance.beams) for instance in instances_batch
                    )
                )
706
                instance_start_and_end: list[tuple[int, int]] = list(
707
708
                    zip(pos[:-1], pos[1:])
                )
709
710
711
712
713
714

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
715
716
717
718
719
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
720
721
722

                # only runs for one step
                # we don't need to use tqdm here
723
724
725
726
727
728
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
729

730
731
732
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
733
734
735
736
737
738
739
740
741
742
743
744
745
746
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
747
                                    logprobs=current_beam.logprobs + [logprobs],
748
                                    lora_request=current_beam.lora_request,
749
750
751
752
753
754
755
756
757
758
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
759
760
761
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
762
763
764
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
765
                    instance.beams = sorted_beams[:beam_width]
766
767
768
769

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
770
771
772
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
773
774
775
776
777
778
779
780
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

781
    def preprocess_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[TokensPrompt]:
        """
        Render and tokenize one or more chat conversations.

        Each conversation is parsed, interpolated into the chat template and
        tokenized, so that the result can be fed directly to the other LLM
        methods (for example `generate`).

        Refer to `chat` for a complete description of the arguments.

        Returns:
            A list with one `TokensPrompt` per conversation, holding the
            prompt token IDs produced after chat template interpolation plus
            any multi-modal data, multi-modal UUIDs and `mm_processor_kwargs`.
        """
        # A single conversation is wrapped so both call styles share one path.
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list)
            else [cast(list[ChatCompletionMessageParam], messages)]
        )

        tokenizer = self.get_tokenizer()
        model_config = self.model_config

        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Entries from the caller's chat_template_kwargs win over the explicit
        # keyword arguments of this method.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        tokenized: list[TokensPrompt] = []
        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already emits the special tokens, so the
                # tokenizer must not insert them a second time.
                token_ids = tokenizer.encode(rendered, add_special_tokens=False)

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_uuids is not None:
                engine_prompt["multi_modal_uuids"] = mm_uuids
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs

            tokenized.append(engine_prompt)

        return tokenized

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The conversation(s) are rendered and tokenized with
        `preprocess_chat`, and the resulting prompts are handed to the
        [generate][vllm.LLM.generate] method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, asks the chat template to add the
                generation prompt at the end of each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions, forwarded to the chat template
                and used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. They take precedence over the explicit arguments
                above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        # Rendering/tokenization happens first; its output feeds generation.
        return self.generate(
            self.preprocess_chat(
                messages=messages,
                chat_template=chat_template,
                chat_template_content_format=chat_template_content_format,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                chat_template_kwargs=chat_template_kwargs,
                mm_processor_kwargs=mm_processor_kwargs,
            ),
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

962
963
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict
                containing a `"data"` key is treated as an IOProcessor plugin
                request.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. The objects passed in are
                not modified; verified copies are sent to the engine.
            truncate_prompt_tokens: If not None, overrides
                `truncate_prompt_tokens` on every pooling parameter object
                (kept for backwards compatibility).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run. Required, and must be one
                of the tasks supported by the loaded model.
            tokenization_kwargs: Accepted for interface compatibility; not
                currently used by this method.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If `pooling_task` is missing or unsupported, if the
                model is not a pooling model, or if an IOProcessor request is
                given while no IOProcessor plugin is installed.
        """
        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        model_config = self.model_config
        if model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        # NOTE(review): `tokenization_kwargs` is not consumed anywhere in this
        # method; confirm whether it should be forwarded to request creation.

        io_processor_prompt = isinstance(prompts, dict) and "data" in prompts
        if io_processor_prompt:
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

            # Let the plugin validate (or synthesize) the pooling params.
            if is_list_of(pooling_params, PoolingParams):
                pooling_params = [
                    self.io_processor.validate_or_generate_params(param)
                    for param in as_iter(pooling_params)
                ]
            else:
                assert not isinstance(pooling_params, Sequence)
                pooling_params = self.io_processor.validate_or_generate_params(
                    pooling_params
                )
        elif pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        if pooling_task not in self.supported_tasks:
            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

        # Work on private copies: verify() and the truncation override below
        # modify the objects in place, and those objects may belong to the
        # caller (and be reused across calls with different settings).
        if isinstance(pooling_params, PoolingParams):
            pooling_params = pooling_params.clone()
        else:
            pooling_params = [param.clone() for param in pooling_params]

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if not io_processor_prompt:
            return model_outputs

        # get the post-processed model outputs
        assert self.io_processor is not None
        processed_outputs = self.io_processor.post_process(
            model_output=model_outputs
        )
        return [
            PoolingRequestOutput[Any](
                request_id="",
                outputs=processed_outputs,
                num_cached_tokens=getattr(
                    processed_outputs, "num_cached_tokens", 0
                ),
                prompt_token_ids=[],
                finished=True,
            )
        ]
1111

1112
1113
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If not None, overrides the prompt
                truncation length on the pooling parameters.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"embed"` task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        pooled = self.encode(
            prompts,
            pooling_task="embed",
            lora_request=lora_request,
            pooling_params=pooling_params,
            use_tqdm=use_tqdm,
            truncate_prompt_tokens=truncate_prompt_tokens,
        )
        # Narrow the generic pooling outputs to embedding outputs.
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        truncate_prompt_tokens: int | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: If not None, overrides the prompt
                truncation length on the pooling parameters, as in `embed`
                and `reward`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"classify"` task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1207
1208
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        This is `encode` with the `"token_classify"` pooling task.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If not None, overrides the prompt
                truncation length on the pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        encode_kwargs: dict[str, Any] = dict(
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="token_classify",
        )
        return self.encode(prompts, **encode_kwargs)

1245
1246
1247
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str | TextPrompt | TokensPrompt],
        text_2: list[str | TextPrompt | TokensPrompt],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by cosine similarity of their embeddings.

        Both sides are embedded in a single `encode` call. When the first
        side yields exactly one embedding, it is reused against every
        embedding of the second side.
        """
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            pooling_task="embed",
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
            truncate_prompt_tokens=truncate_prompt_tokens,
        )

        # The outputs keep input order, so the first len(text_1) belong to text_1.
        split_at = len(text_1)
        left, right = embeddings[:split_at], embeddings[split_at:]

        if len(left) == 1:
            left = left * len(right)

        scores = _cosine_similarity(tokenizer=tokenizer, embed_1=left, embed_2=right)

        validated = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
        return list(map(ScoringRequestOutput.from_base, validated))

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: list[str] | list[ScoreContentPartParam],
        data_2: list[str] | list[ScoreContentPartParam],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score `(data_1[i], data_2[i])` pairs with a cross-encoder model.

        If `data_1` has a single element, it is paired with every element of
        `data_2`. Pairs are formed with `zip`.

        Raises:
            ValueError: If the tokenizer is a Mistral tokenizer.
        """
        model_config = self.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        # verify() modifies the params in place; work on a copy so the
        # caller's object is left untouched.
        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        else:
            pooling_params = pooling_params.clone()

        pooling_params.verify("score", model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
        )

        prompts = list[PromptType]()
        pooling_params_list = list[PoolingParams]()

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            # Token type ids travel to the model via the pooling params, so
            # prompts that carry them get their own params copy.
            if token_type_ids := engine_prompt.pop("token_type_ids", None):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                # Merge rather than replace, so extra_kwargs set by the
                # caller are kept (as they are for prompts without ids).
                params.extra_kwargs = {
                    **(params.extra_kwargs or {}),
                    "compressed_token_type_ids": compressed,
                }
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1342
1343
    def score(
        self,
        data_1: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        For models that are not multi-modal, every prompt is first reduced to
        plain text: a `TextPrompt` contributes its `"prompt"` field and a
        `TokensPrompt` is decoded with the model's tokenizer.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See
                [PromptType][vllm.inputs.PromptType] for more details about
                the format of each prompt.
            truncate_prompt_tokens: Optional prompt-truncation length,
                forwarded unchanged to the cross-encoding or embedding
                scoring helper.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a pooling model, supports
                neither the `embed` nor the `classify` task, or is a
                cross-encoder whose `num_labels` is not 1. Also raised, for
                models that are not multi-modal, when a
                `ScoreMultiModalParam`, a prompt carrying multi-modal data, or
                a prompt that cannot be reduced to text is passed. Validation
                of the pair counts is delegated to
                `_validate_score_input_lens`.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        if (
            model_config.is_cross_encoder
            and getattr(model_config.hf_config, "num_labels", 0) != 1
        ):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(
                data: SingletonPrompt
                | Sequence[SingletonPrompt]
                | ScoreMultiModalParam,
            ):
                # A top-level dict with "content" is a ScoreMultiModalParam,
                # which a text-only model cannot consume.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError(
                        "ScoreMultiModalParam is not supported "
                        f"for {model_config.architecture}"
                    )

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt) -> str:
                # Reduce one prompt to plain text for a text-only model.
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError(
                            "Multi-modal prompt is not supported for scoring"
                        )
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"]
                        )
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check rather than `assert`: asserts are stripped
                # under `python -O`, which would let non-text prompts through
                # to the tokenizer.
                if not isinstance(prompt, str):
                    raise ValueError(
                        "Expected a text prompt, a `TextPrompt` or a "
                        "`TokensPrompt` for scoring, got "
                        f"{type(prompt).__name__}"
                    )
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Unwrap ScoreMultiModalParam into its content list; promote a bare
        # string to a one-element list.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )
1491

1492
1493
1494
1495
1496
1497
    def start_profile(self) -> None:
        """Ask the underlying `llm_engine` to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying `llm_engine` to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1498
1499
    def reset_prefix_cache(self) -> None:
        """Delegate a prefix-cache reset to the underlying `llm_engine`."""
        engine = self.llm_engine
        engine.reset_prefix_cache()
1500

1501
1502
1503
1504
1505
1506
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep; it must not process requests while asleep.

        The caller is responsible for making sure no requests are in flight
        from now until `wake_up` is called. The prefix cache is reset before
        the engine is put to sleep.

        Args:
            level: How deeply to sleep.
                Level 1 offloads the model weights to CPU memory and discards
                the kv cache, whose contents are lost. Use it when the same
                model will be run again after waking up; make sure there is
                enough CPU memory to hold the model weights.
                Level 2 discards both the model weights and the kv cache, so
                the contents of both are lost. Use it when a different or
                updated model will be run after waking up and the previous
                weights are not needed; it reduces CPU memory pressure.
        """
        # Drop cached prefixes first so none are left behind across sleep.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1523
    def wake_up(self, tags: list[str] | None = None):
1524
        """
1525
1526
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1527

1528
        Args:
1529
1530
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1531
                `("weights", "kv_cache")`. If None, all memory is reallocated.
1532
                wake_up should be called with all tags (or None) before the
1533
1534
1535
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1536

1537
1538
1539
1540
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus, as reported by the underlying
            `llm_engine`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1549
1550
    def _validate_and_add_requests(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        params: SamplingParams
        | Sequence[SamplingParams]
        | PoolingParams
        | Sequence[PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None,
        priority: list[int] | None = None,
    ) -> None:
        """Validate a batch of prompts and add one engine request per prompt.

        A single prompt (`str` or `dict`) is promoted to a one-element batch.
        `params` and `lora_request` may each be a single value shared by all
        prompts or a sequence with exactly one entry per prompt. `priority`,
        when given, must have one entry per prompt.

        Every `SamplingParams` in `params` is mutated in place so that only
        the final output of each request is reported.

        If adding any request fails, including on `KeyboardInterrupt`, the
        requests already added by this call are aborted before the exception
        propagates.

        Raises:
            ValueError: If `params`, `lora_request` or `priority` has a length
                that differs from the number of prompts, or if a prompt's
                multi-modal data fails `_validate_mm_data_and_uuids`.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]  # type: ignore[list-item]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params must be the same.")
        if isinstance(lora_request, Sequence) and len(lora_request) != num_requests:
            raise ValueError(
                "The lengths of prompts and lora_request must be the same."
            )
        if priority is not None and len(priority) != num_requests:
            raise ValueError(
                "The lengths of prompts "
                f"({num_requests}) and priority ({len(priority)}) "
                "must be the same."
            )

        for sp in params if isinstance(params, Sequence) else (params,):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        added_request_ids: list[str] = []

        try:
            for i, prompt in enumerate(it):
                if isinstance(prompt, dict):
                    self._validate_mm_data_and_uuids(
                        prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids")
                    )
                request_id = self._add_request(
                    prompt,
                    params[i] if isinstance(params, Sequence) else params,
                    lora_request=lora_request[i]
                    if isinstance(lora_request, Sequence)
                    else lora_request,
                    priority=priority[i] if priority else 0,
                )
                added_request_ids.append(request_id)
        except BaseException:
            # BaseException (not just Exception) so that a KeyboardInterrupt
            # mid-batch also rolls back the partially added batch; otherwise
            # those requests stay in the engine and a later `_run_engine`
            # would process and return them.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids)
            # Bare `raise` re-raises the original exception unchanged.
            raise
1611

1612
    def _validate_mm_data_and_uuids(
1613
        self,
1614
1615
        multi_modal_data: Any | None,  # MultiModalDataDict
        multi_modal_uuids: Any | None,  # MultiModalUUIDDict
1616
1617
1618
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1619
        then its corresponding UUID must be set.
1620
1621
1622
1623
1624
1625
1626
1627
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
1628
1629
1630
1631
1632
1633
1634
1635
                        if (
                            multi_modal_uuids is None
                            or modality not in multi_modal_uuids
                            or multi_modal_uuids[  # noqa: E501
                                modality
                            ]
                            is None
                        ):
1636
1637
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
1638
1639
                                f"but UUID is not provided"
                            )
1640
                        else:
1641
1642
1643
1644
                            if (
                                len(multi_modal_uuids[modality]) <= i
                                or multi_modal_uuids[modality][i] is None
                            ):
1645
1646
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
1647
1648
                                    f"but UUID is not provided"
                                )
1649
            else:
1650
1651
1652
1653
1654
1655
1656
1657
1658
                if data is None and (
                    multi_modal_uuids is None
                    or modality not in multi_modal_uuids
                    or multi_modal_uuids[modality] is None
                ):
                    raise ValueError(
                        f"Multi-modal data for {modality} is None"
                        f" but UUID is not provided"
                    )
1659

1660
1661
1662
1663
    def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Build an engine request for one prompt via `self.input_processor`.

        A fresh tokenization-kwargs dict is passed through
        `_validate_truncation_size` (with the model's `max_model_len` and the
        params' `truncate_prompt_tokens`) and then handed to the input
        processor.

        Returns:
            A `(engine_request, tokenization_kwargs)` pair, where the second
            item is the same dict object given to the input processor.
        """
        tok_kwargs: dict[str, Any] = {}
        # NOTE(review): the dict is presumably filled in place with the
        # truncation settings by `_validate_truncation_size` — confirm.
        _validate_truncation_size(
            self.model_config.max_model_len,
            params.truncate_prompt_tokens,
            tok_kwargs,
        )

        processor = self.input_processor
        core_request = processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tok_kwargs,
            priority=priority,
        )
        return core_request, tok_kwargs

1687
    def _add_request(
        self,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
    ) -> str:
        """Process a single prompt and submit it to `self.llm_engine`.

        A new request ID is taken from `self.request_counter` (as a string),
        the prompt is turned into an engine request by `_process_inputs`, and
        the result is added to the engine together with the prompt's text
        component.

        Returns:
            The request ID assigned to the new request.
        """
        # The text component is forwarded to the engine as `prompt_text`.
        text, _, _ = get_prompt_components(prompt)
        new_id = str(next(self.request_counter))

        core_request, tok_kwargs = self._process_inputs(
            new_id, prompt, params, lora_request=lora_request, priority=priority
        )

        engine = self.llm_engine
        engine.add_request(
            new_id,
            core_request,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tok_kwargs,
            priority=priority,
            prompt_text=text,
        )
        return new_id
1715

1716
    def _run_engine(
        self, *, use_tqdm: bool | Callable[..., tqdm] = True
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If truthy, show a progress bar over the requests that
                are unfinished when this method starts. A callable is used as
                the progress-bar factory; otherwise `tqdm` is used. For
                `RequestOutput`s the bar also shows estimated input/output
                token throughput.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs
                        )
                        # A disabled bar (e.g. created with `disable=True`)
                        # reports 0 elapsed time; don't divide by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                        else:
                            in_spd = out_spd = 0.0
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s"
                        )
                        pbar.update(n)
                    else:
                        pbar.update(1)
                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Close the bar even if stepping the engine raises.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))