llm.py 71.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Sequence
6
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
    CompilationConfig,
    ModelDType,
    StructuredOutputsConfig,
    TokenizerMode,
    is_init_field,
)
from vllm.engine.arg_utils import (
    ConvertOption,
    EngineArgs,
    HfOverrides,
    PoolerConfig,
    RunnerOption,
)
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
    resolve_chat_template_content_format,
)
from vllm.entrypoints.score_utils import (
    ScoreContentPartParam,
    ScoreMultiModalParam,
    _cosine_similarity,
    _validate_score_input_lens,
    compress_token_type_ids,
    get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
from vllm.inputs import (
    DataPrompt,
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
58
from vllm.inputs.parse import get_prompt_components
59
from vllm.logger import init_logger
60
from vllm.lora.request import LoRARequest
61
from vllm.model_executor.layers.quantization import QuantizationMethods
62
63
64
65
66
67
68
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
69
from vllm.plugins.io_processors import get_io_processor
70
from vllm.pooling_params import PoolingParams
71
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
72
from vllm.tasks import PoolingTask
73
74
75
76
77
from vllm.transformers_utils.tokenizer import (
    AnyTokenizer,
    MistralTokenizer,
    get_cached_tokenizer,
)
yhu422's avatar
yhu422 committed
78
from vllm.usage.usage_lib import UsageContext
79
from vllm.utils import Counter, Device, as_iter, is_list_of
80
from vllm.v1.engine import EngineCoreRequest
81
from vllm.v1.engine.llm_engine import LLMEngine
82
from vllm.v1.engine.processor import Processor
83
from vllm.v1.sample.logits_processor import LogitsProcessor
84

85
86
87
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

88
89
logger = init_logger(__name__)

90
91
_R = TypeVar("_R", default=Any)

92
93

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
94
95
96
97
98
99
100
101
102
103
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
104
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
105
106
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
107
108
109
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
110
111
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
112
113
114
115
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
116
        allowed_media_domains: If set, only media URLs that belong to this
117
            domain can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
118
119
120
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
121
122
123
124
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
125
        quantization: The method used to quantize the model weights. Currently,
126
            we support "awq", "gptq", and "fp8" (experimental).
127
128
129
130
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
131
132
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
133
134
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
135
136
137
138
139
140
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
141
142
143
144
145
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vllm can automatically infer the kv cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the kv cache memory size. kv_cache_memory_bytes
            allows more fine-grain control of how much memory gets used when
146
            compared with using gpu_memory_utilization. Note that
            kv_cache_memory_bytes (when not None) takes precedence and
            gpu_memory_utilization is ignored.
149
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
150
151
152
153
154
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Noting that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
155
156
157
158
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
159
160
161
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
162
163
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
164
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
167
168
169
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
170
171
172
173
174
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
175
176
177
178
179
        pooler_config: Initialize non-default pooling config for the pooling
            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
        override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
            argument is deprecated and will be removed in v0.12.0 or v1.0.0,
            whichever is sooner.
180
181
182
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
183
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
184

185
186
    Note:
        This class is intended to be used for offline inference. For online
187
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
188
    """
189
190
191
192

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: Optional[list[str]] = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        pooler_config: Optional[PoolerConfig] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        structured_outputs_config: Optional[
            Union[dict[str, Any], StructuredOutputsConfig]
        ] = None,
        kv_cache_memory_bytes: Optional[int] = None,
        compilation_config: Optional[
            Union[int, dict[str, Any], CompilationConfig]
        ] = None,
        logits_processors: Optional[list[Union[str, type[LogitsProcessor]]]] = None,
        **kwargs: Any,
    ) -> None:
        """Build the engine arguments and create the underlying engine.

        The constructor normalizes a few loosely-typed inputs (extra
        `kwargs`, `compilation_config`, `structured_outputs_config`) into
        their config objects, forwards everything to `EngineArgs`, and then
        instantiates the engine together with the bookkeeping attributes
        used by the other methods of this class.

        Raises:
            ValueError: If `kv_transfer_config` is passed as a dict that
                fails validation when converted to `KVTransferConfig`.
        """
        # Stats logging stays off unless the caller explicitly asked for it.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class given as an actual type (rather than a qualified
        # string name) is serialized with cloudpickle to avoid pickling
        # issues.
        custom_worker = kwargs.get("worker_cls")
        if isinstance(custom_worker, type):
            kwargs["worker_cls"] = cloudpickle.dumps(custom_worker)

        kv_transfer_raw = kwargs.get("kv_transfer_config")
        if isinstance(kv_transfer_raw, dict):
            from vllm.config.kv_transfer import KVTransferConfig

            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**kv_transfer_raw)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    kv_transfer_raw,
                    e,
                )
                # Surface the validation failure as a ValueError so the
                # caller sees which argument was at fault.
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        hf_overrides = {} if hf_overrides is None else hf_overrides

        # Normalize `compilation_config` into a CompilationConfig instance.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            # An integer is interpreted as the compilation level.
            compilation_config_instance = CompilationConfig(level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Only keys that are init fields of CompilationConfig are kept.
            accepted_compilation_fields = {
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            }
            compilation_config_instance = CompilationConfig(
                **accepted_compilation_fields
            )
        else:
            compilation_config_instance = compilation_config

        # Normalize `structured_outputs_config` the same way.
        if structured_outputs_config is None:
            structured_outputs_instance = StructuredOutputsConfig()
        elif isinstance(structured_outputs_config, dict):
            accepted_structured_fields = {
                key: value
                for key, value in structured_outputs_config.items()
                if is_init_field(StructuredOutputsConfig, key)
            }
            structured_outputs_instance = StructuredOutputsConfig(
                **accepted_structured_fields
            )
        else:
            structured_outputs_instance = structured_outputs_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            override_pooler_config=override_pooler_config,
            structured_outputs_config=structured_outputs_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        # Create the engine from the assembled arguments.
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args,
            usage_context=UsageContext.LLM_CLASS,
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled lazily by `get_default_sampling_params`.
        self.default_sampling_params: Union[dict[str, Any], None] = None

        supported_tasks = self.llm_engine.get_supported_tasks()  # type: ignore
        logger.info("Supported_tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        # Load the input/output processor plugin, if one is configured.
        self.io_processor = get_io_processor(
            self.llm_engine.vllm_config,
            self.llm_engine.model_config.io_processor_plugin,
        )
349

350
351
352
353
    @property
    def model_config(self):
        return self.llm_engine.model_config

354
355
    def get_tokenizer(self) -> AnyTokenizer:
        return self.llm_engine.get_tokenizer()
356

357
    @deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Replace the engine's tokenizer, wrapping it in a cache if needed.

        Args:
            tokenizer: The tokenizer to install on the engine. It is passed
                through `get_cached_tokenizer` unless its class name already
                starts with "Cached".
        """
        # The cached tokenizer class is created dynamically, so the class
        # name is the only thing available to compare against. A
        # user-defined tokenizer whose class name starts with "Cached" will
        # therefore be misjudged as already cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        self.llm_engine.tokenizer = (
            tokenizer if looks_cached else get_cached_tokenizer(tokenizer)
        )
366

367
368
369
    def _get_processor(self) -> Processor:
        if not hasattr(self, "_processor"):
            vllm_config = self.llm_engine.vllm_config
370
371
            self._processor = Processor(vllm_config)

372
373
        return self._processor

374
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the sampling parameters used when the caller gives none.

        The overrides are read from the model config via
        `get_diff_sampling_param()` on first call and remembered in
        `self.default_sampling_params`. When no overrides exist, a plain
        `SamplingParams()` is returned.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = model_config.get_diff_sampling_param()

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

383
384
385
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
386
387
388
        sampling_params: Optional[
            Union[SamplingParams, Sequence[SamplingParams]]
        ] = None,
389
        *,
390
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
391
392
393
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
Woosuk Kwon's avatar
Woosuk Kwon committed
394
395
        """Generates the completions for the input prompts.

396
        This class automatically batches the given prompts, considering
Woosuk Kwon's avatar
Woosuk Kwon committed
397
398
399
400
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
401
            prompts: The prompts to the LLM. You may pass a sequence of prompts
402
                for batch inference. See [PromptType][vllm.inputs.PromptType]
403
                for more details about the format of each prompt.
Woosuk Kwon's avatar
Woosuk Kwon committed
404
            sampling_params: The sampling parameters for text generation. If
nunjunj's avatar
nunjunj committed
405
406
407
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
408
                prompts and it is paired one by one with the prompt.
409
410
411
412
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
413
            lora_request: LoRA request to use for generation, if any.
414
415
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
Woosuk Kwon's avatar
Woosuk Kwon committed
416
417

        Returns:
418
            A list of `RequestOutput` objects containing the
419
            generated completions in the same order as the input prompts.
420

421
422
423
424
        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `inputs` parameter.
425
        """
426
427
428
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
429
430
431
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
432
433
                "generative model."
            )
434

435
436
        if sampling_params is None:
            # Use default sampling params.
437
            sampling_params = self.get_default_sampling_params()
438

439
        # Add any modality specific loras to the corresponding prompts
440
        lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)
441

442
        self._validate_and_add_requests(
443
            prompts=prompts,
444
            params=sampling_params,
445
            use_tqdm=use_tqdm,
446
            lora_request=lora_request,
447
448
            priority=priority,
        )
449

450
        outputs = self._run_engine(use_tqdm=use_tqdm)
Joe Runde's avatar
Joe Runde committed
451
        return self.engine_class.validate_outputs(outputs, RequestOutput)
452

453
    def _get_modality_specific_lora_reqs(
454
455
456
457
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
    ):
458
459
460
461
462
463
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
464
465
466
467
468
        if (
            lora_config is None
            or not self.llm_engine.model_config.is_multimodal_model
            or (lora_config and lora_config.default_mm_loras is None)
        ):
469
470
            return lora_request

471
472
        if not isinstance(prompts, Sequence):
            prompts = [prompts]
473

474
475
476
477
478
        optional_loras = (
            [lora_request] * len(prompts)
            if not isinstance(lora_request, Sequence)
            else lora_request
        )
479
480
481

        return [
            self._resolve_single_prompt_mm_lora(
482
                prompt,
483
484
                opt_lora_req,
                lora_config.default_mm_loras,
485
486
            )
            for prompt, opt_lora_req in zip(prompts, optional_loras)
487
488
        ]

489
490
491
492
493
494
495
496
497
498
499
    def _resolve_single_prompt_mm_lora(
        self,
        prompt: PromptType,
        lora_request: Optional[LoRARequest],
        default_mm_loras: Optional[dict[str, str]],
    ):
        if (
            not default_mm_loras
            or not isinstance(prompt, dict)
            or "multi_modal_data" not in prompt
        ):
500
501
            return lora_request

502
        prompt = cast(Union[TextPrompt, TokensPrompt], prompt)
503

504
505
506
        intersection = set(prompt["multi_modal_data"].keys()).intersection(
            default_mm_loras.keys()
        )
507
508
509
510
511
512
513
514
515
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
516
517
518
                " will be skipped",
                intersection,
            )
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # If we have a collision, warn if there is a collision,
        # but always send the explicitly provided request.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
534
535
                    "lora_request as we only apply one LoRARequest per prompt"
                )
536
537
538
539
540
541
542
543
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

544
545
546
547
548
549
550
    def collective_rpc(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> list[_R]:
551
552
553
554
555
556
557
558
559
560
561
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
562
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
563
564
565
566
567
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.
568

569
570
571
        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
572
        """
573
574

        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
575
576

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
577
        """
578
579
        Run a function directly on the model inside each worker,
        returning the result for each of them.
580
581
582
583
584
585

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
586
        """
587
        return self.llm_engine.apply_model(func)
588

589
590
591
592
593
594
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt."""
595
        if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts):
596
            raise ValueError(
597
598
                "Lora request list should be the same length as the prompts"
            )
599
600
601
602
603
604

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

605
606
    def beam_search(
        self,
607
        prompts: list[Union[TokensPrompt, TextPrompt]],
608
        params: BeamSearchParams,
609
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
610
        use_tqdm: bool = False,
611
        concurrency_limit: Optional[int] = None,
612
    ) -> list[BeamSearchOutput]:
613
614
615
616
617
618
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
619
            params: The beam search parameters.
620
            lora_request: LoRA request to use for generation, if any.
621
            use_tqdm: Whether to use tqdm to display the progress bar.
622
623
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
624
        """
625
626
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
627
628
629
630
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
631
632
        length_penalty = params.length_penalty

633
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
634

635
636
637
638
639
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
640

641
642
643
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
644
645
                "Disabling progress bar."
            )
646
647
648
649
650
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

651
652
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
653
654
655
656
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
657
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
658
            return TokensPrompt(**token_prompt_kwargs)
659

660
661
662
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
663
664
665
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width, max_tokens=1, temperature=temperature
        )
666
        instances: list[BeamSearchInstance] = []
667

668
        for lora_req, prompt in zip(lora_requests, prompts):
669
670
671
672
673
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
674
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
675

676
677
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
678
679
680
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
681

682
            instances.append(
683
684
685
686
687
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
688
689
                ),
            )
690

691
        for prompt_start in range(0, len(prompts), concurrency_limit):
692
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
693
694
695

            token_iter = range(max_tokens)
            if use_tqdm:
696
697
698
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
699
700
701
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
702
703
                    "reflect instance-level progress."
                )
704
705
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
706
707
                    sum((instance.beams for instance in instances_batch), [])
                )
708
709
                pos = [0] + list(
                    itertools.accumulate(
710
711
712
                        len(instance.beams) for instance in instances_batch
                    )
                )
713
                instance_start_and_end: list[tuple[int, int]] = list(
714
715
                    zip(pos[:-1], pos[1:])
                )
716
717
718
719
720
721

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
722
723
724
725
726
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
727
728
729

                # only runs for one step
                # we don't need to use tqdm here
730
731
732
733
734
735
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
736

737
738
739
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
740
741
742
743
744
745
746
747
748
749
750
751
752
753
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
754
                                    logprobs=current_beam.logprobs + [logprobs],
755
                                    lora_request=current_beam.lora_request,
756
757
758
759
760
761
762
763
764
765
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
766
767
768
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
769
770
771
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
772
                    instance.beams = sorted_beams[:beam_width]
773
774
775
776

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
777
778
779
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
780
781
782
783
784
785
786
787
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

788
    def preprocess_chat(
        self,
        messages: Union[
            list[ChatCompletionMessageParam], list[list[ChatCompletionMessageParam]]
        ],
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[TokensPrompt]:
        """
        Convert one or more chat conversations into tokenized prompts.

        The returned prompts can be passed to the other LLM methods.
        Refer to `chat` for a complete description of the arguments.

        Returns:
            A list with one `TokensPrompt` per conversation. Each prompt
            holds the token ids produced after applying the chat template
            and, when present, the multi-modal data, the multi-modal UUIDs
            and the given `mm_processor_kwargs`.
        """
        # Normalize the input to a batch: a list of lists is already a batch
        # of conversations, anything else is wrapped as a single conversation.
        conversations: list[list[ChatCompletionMessageParam]]
        if is_list_of(messages, list):
            conversations = cast(list[list[ChatCompletionMessageParam]], messages)
        else:
            conversations = [cast(list[ChatCompletionMessageParam], messages)]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Explicit arguments form the base; caller-supplied template kwargs
        # are applied afterwards and therefore win on key collisions.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
        }
        template_kwargs.update(chat_template_kwargs or {})

        tokenized_prompts: list[TokensPrompt] = []

        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # This branch uses the template result as token ids directly
                # and is given the raw messages, not the parsed conversation.
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered_prompt = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                token_ids = tokenizer.encode(
                    rendered_prompt, add_special_tokens=False
                )

            tokens_prompt = TokensPrompt(prompt_token_ids=token_ids)

            # Optional fields are only set when a value is available.
            if mm_data is not None:
                tokens_prompt["multi_modal_data"] = mm_data

            if mm_uuids is not None:
                tokens_prompt["multi_modal_uuids"] = mm_uuids

            if mm_processor_kwargs is not None:
                tokens_prompt["mm_processor_kwargs"] = mm_processor_kwargs

            tokenized_prompts.append(tokens_prompt)

        return tokenized_prompts

    def chat(
        self,
        messages: Union[
            list[ChatCompletionMessageParam], list[list[ChatCompletionMessageParam]]
        ],
        sampling_params: Optional[Union[SamplingParams, list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The conversation(s) are first turned into tokenized prompts by
        `preprocess_chat`, and those prompts are then handed to the
        [generate][vllm.LLM.generate] method.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional list of tool definitions; forwarded unchanged to
                `preprocess_chat`.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        # Preprocessing runs to completion before generation starts, because
        # its result is the first argument of the `generate` call.
        return self.generate(
            self.preprocess_chat(
                messages=messages,
                chat_template=chat_template,
                chat_template_content_format=chat_template_content_format,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                chat_template_kwargs=chat_template_kwargs,
                mm_processor_kwargs=mm_processor_kwargs,
            ),
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

971
972
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
        pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: Optional[PoolingTask] = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict
                containing a `"data"` key is instead routed through the
                installed IOProcessor plugin (parse, pre-process, and
                post-process of the outputs).
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Kept for backwards compatibility. When not
                None, it is written to `truncate_prompt_tokens` on every
                pooling params object used for this call (including the ones
                passed in by the caller).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use. If `None`, the
                task is chosen from `self.supported_tasks` ("embed" when
                available, otherwise "encode") and a one-time warning is
                logged unless "encode" is the only supported task.
            tokenization_kwargs: overrides tokenization_kwargs set in
                pooling_params.
                NOTE(review): this argument is not referenced anywhere in
                this method's body; confirm whether it should be forwarded.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For IOProcessor prompts, a single-element list wrapping the
            plugin's post-processed output.

        Raises:
            ValueError: If the model's runner type is not "pooling", if
                `pooling_task` is not in `self.supported_tasks`, or if an
                IOProcessor prompt is given but no plugin is installed.
        """
        if pooling_task is None:
            if self.supported_tasks == ["encode"]:
                # Only one possible task: pick it silently.
                pooling_task = "encode"
            else:
                pooling_task = (
                    "embed" if "embed" in self.supported_tasks else "encode"
                )

                logger.warning_once(
                    "`LLM.encode` is currently using `pooling_task = %s`.\n"
                    "Please use one of the more specific methods or set the "
                    "task directly when using `LLM.encode`:\n"
                    "  - For embeddings, use `LLM.embed(...)` "
                    'or `pooling_task="embed"`.\n'
                    "  - For classification logits, use `LLM.classify(...)` "
                    'or `pooling_task="classify"`.\n'
                    "  - For rewards, use `LLM.reward(...)` "
                    'or `pooling_task="reward"`\n'
                    "  - For similarity scores, use `LLM.score(...)`.",
                    pooling_task,
                )

        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if pooling_task not in self.supported_tasks:
            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        # A dict prompt carrying a "data" key is handled by the IOProcessor
        # plugin instead of being sent to the engine as-is.
        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if not io_processor_prompt:
            return model_outputs

        # get the post-processed model outputs
        assert self.io_processor is not None
        processed_outputs = self.io_processor.post_process(
            model_output=model_outputs
        )

        # The plugin output is wrapped in a single finished request output
        # with an empty request id and no prompt token ids.
        return [
            PoolingRequestOutput[Any](
                request_id="",
                outputs=processed_outputs,
                prompt_token_ids=[],
                finished=True,
            )
        ]
1103

1104
1105
1106
1107
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This is `encode` with `pooling_task="embed"`, with each result
        converted via `EmbeddingRequestOutput.from_base`. For the best
        performance, put all of your prompts into a single list and pass it
        to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded unchanged to `encode`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If "embed" is not among `self.supported_tasks`.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        pooled_outputs = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
        )

        return [EmbeddingRequestOutput.from_base(pooled) for pooled in pooled_outputs]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This is `encode` with `pooling_task="classify"`, with each result
        converted via `ClassificationRequestOutput.from_base`. For the best
        performance, put all of your prompts into a single list and pass it
        to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ClassificationRequestOutput` objects, one per input
            prompt and in the same order as the input prompts.

        Raises:
            ValueError: If "classify" is not among `self.supported_tasks`.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        pooled_outputs = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [
            ClassificationRequestOutput.from_base(pooled) for pooled in pooled_outputs
        ]

1199
1200
1201
1202
1203
1204
1205
    def reward(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams, Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        This delegates to `encode` with `pooling_task="encode"` and returns
        its result unchanged.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded unchanged to `encode`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        pooled_outputs = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_task="encode",
        )
        return pooled_outputs

1237
1238
1239
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """
        Score `text_1` against `text_2` from their embeddings.

        Both lists are embedded in a single `encode` call with
        `pooling_task="embed"`, the results are split back into the two
        groups, and `_cosine_similarity` is applied to them.

        Args:
            tokenizer: Tokenizer passed through to `_cosine_similarity`.
            text_1: First group of inputs. When it yields exactly one
                embedding, that embedding is repeated once per `text_2` entry.
            text_2: Second group of inputs.
            truncate_prompt_tokens: Forwarded unchanged to `encode`.
            use_tqdm: Forwarded unchanged to `encode`.
            pooling_params: Forwarded unchanged to `encode`.
            lora_request: Forwarded unchanged to `encode`.

        Returns:
            A list of `ScoringRequestOutput` objects built from the validated
            similarity results.
        """
        pooled: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="embed",
        )

        # The first len(text_1) results belong to text_1, the rest to text_2.
        split_at = len(text_1)
        first_group: list[PoolingRequestOutput] = pooled[:split_at]
        second_group: list[PoolingRequestOutput] = pooled[split_at:]

        # 1 -> N: repeat the single first-group embedding for every partner.
        if len(first_group) == 1:
            first_group = first_group * len(second_group)

        similarity = _cosine_similarity(
            tokenizer=tokenizer, embed_1=first_group, embed_2=second_group
        )

        validated = self.engine_class.validate_outputs(
            similarity, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(entry) for entry in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """
        Score `(data_1, data_2)` pairs by running each pair through the model.

        One engine prompt is built per pair with `get_score_prompt`; the
        prompts are then submitted to the engine and the pooled outputs are
        converted to `ScoringRequestOutput`.

        Args:
            tokenizer: Tokenizer used to build the score prompts. Must not be
                a `MistralTokenizer`.
            data_1: First elements of the pairs. A single-element list is
                repeated once per `data_2` entry.
            data_2: Second elements of the pairs.
            truncate_prompt_tokens: Passed to `_validate_truncation_size`
                together with the model's `max_model_len` and the
                tokenization kwargs used for every pair.
            use_tqdm: Progress-bar setting passed to request submission and
                to the engine run.
            pooling_params: Pooling parameters; defaults to
                `PoolingParams(task="score")` when None.
            lora_request: LoRA request to use, if any.

        Returns:
            A list of `ScoringRequestOutput` objects, one per engine output.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        model_config = self.llm_engine.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        # 1 -> N: repeat the single first element for every partner.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")

        pooling_params.verify("score", model_config)
        pooling_params_list = list[PoolingParams]()

        tokenization_kwargs: dict[str, Any] = {}

        _validate_truncation_size(
            model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
        )

        prompts = list[PromptType]()

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            # Token type ids are removed from the prompt; when present they
            # are attached (compressed) to a per-request clone of the pooling
            # params so the shared params object is left untouched.
            if token_type_ids := engine_prompt.pop("token_type_ids", None):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1337
1338
    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt], ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt], ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Forwarded unchanged to the scoring
                implementation (presumably the number of prompt tokens to
                keep when truncating -- confirm against the scoring helpers).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a pooling model, if it
                supports neither the "embed" nor the "classify" task, if it
                is a cross encoder whose `num_labels` is not 1, or if
                multi-modal inputs are given to a model that is not
                multi-modal.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        if (
            model_config.is_cross_encoder
            and getattr(model_config.hf_config, "num_labels", 0) != 1
        ):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:
            # Text-only model: reject multi-modal inputs and reduce every
            # prompt to a plain string.

            def check_data_type(
                data: Union[
                    SingletonPrompt, Sequence[SingletonPrompt], ScoreMultiModalParam
                ],
            ):
                if isinstance(data, dict) and "content" in data:
                    raise ValueError(
                        "ScoreMultiModalParam is not supported "
                        f"for {model_config.architecture}"
                    )

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt):
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError(
                            "Multi-modal prompt is not supported for scoring"
                        )
                    elif "prompt_token_ids" in prompt:
                        # Token prompts are decoded back to text because the
                        # scoring path works on strings (see comment above).
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"]
                        )
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                assert type(prompt) is str
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Multi-modal params carry their items under "content"; a bare string
        # becomes a one-element list.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )
1486

1487
1488
1489
1490
1491
1492
    def start_profile(self) -> None:
        """Ask the underlying engine to start profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1493
1494
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Forward a prefix-cache reset to the engine and return its result.

        Args:
            device: Passed through unchanged to the engine's
                `reset_prefix_cache`; `None` by default.
        """
        engine = self.llm_engine
        outcome = engine.reset_prefix_cache(device)
        return outcome
1495

1496
1497
1498
1499
1500
1501
    def sleep(self, level: int = 1) -> None:
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level.

                Level 1 sleep will offload the model weights and discard the
                kv cache. The content of kv cache is forgotten. Level 1 sleep
                is good for sleeping and waking up the engine to run the same
                model again. The model weights are backed up in CPU memory.
                Please make sure there's enough CPU memory to store the model
                weights.

                Level 2 sleep will discard both the model weights and the kv
                cache. The content of both the model weights and kv cache is
                forgotten. Level 2 sleep is good for sleeping and waking up
                the engine to run a different model or update the model,
                where previous model weights are not needed. It reduces CPU
                memory pressure.
        """
        # The prefix cache is reset first, then the engine is put to sleep.
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

1518
    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        """
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                `("weights", "kv_cache")`. If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        # `tags` is handed to the engine as-is; no validation happens here.
        self.llm_engine.wake_up(tags)
1531

1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        The call is delegated to the engine's `get_metrics()`.

        Returns:
            A list of `Metric` objects capturing the current state
            of all aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine.
        """
        engine = self.llm_engine
        snapshot = engine.get_metrics()
        return snapshot

1544
1545
    def _validate_and_add_requests(
        self,
1546
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
1547
1548
1549
1550
1551
1552
        params: Union[
            SamplingParams,
            Sequence[SamplingParams],
            PoolingParams,
            Sequence[PoolingParams],
        ],
1553
        *,
1554
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
1555
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
1556
        priority: Optional[list[int]] = None,
1557
    ) -> None:
1558
        if isinstance(prompts, (str, dict)):
1559
            # Convert a single prompt to a list.
1560
            prompts = [prompts]  # type: ignore[list-item]
1561

1562
        num_requests = len(prompts)
1563
        if isinstance(params, Sequence) and len(params) != num_requests:
1564
1565
1566
1567
1568
1569
1570
            raise ValueError("The lengths of prompts and params must be the same.")
        if isinstance(lora_request, Sequence) and len(lora_request) != num_requests:
            raise ValueError(
                "The lengths of prompts and lora_request must be the same."
            )

        for sp in params if isinstance(params, Sequence) else (params,):
1571
1572
1573
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY
1574

Zhuohan Li's avatar
Zhuohan Li committed
1575
        # Add requests to the engine.
1576
1577
        it = prompts
        if use_tqdm:
1578
1579
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")
1580
1581

        for i, prompt in enumerate(it):
1582
1583
            if isinstance(prompt, dict):
                self._validate_mm_data_and_uuids(
1584
1585
                    prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids")
                )
1586

1587
            self._add_request(
1588
                prompt,
1589
                params[i] if isinstance(params, Sequence) else params,
1590
1591
1592
                lora_request=lora_request[i]
                if isinstance(lora_request, Sequence)
                else lora_request,
1593
                priority=priority[i] if priority else 0,
nunjunj's avatar
nunjunj committed
1594
            )
1595

1596
    def _validate_mm_data_and_uuids(
1597
1598
1599
        self,
        multi_modal_data: Optional[Any],  # MultiModalDataDict
        multi_modal_uuids: Optional[Any],  # MultiModalUUIDDict
1600
1601
1602
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1603
        then its corresponding UUID must be set.
1604
1605
1606
1607
1608
1609
1610
1611
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
1612
1613
1614
1615
1616
1617
1618
1619
                        if (
                            multi_modal_uuids is None
                            or modality not in multi_modal_uuids
                            or multi_modal_uuids[  # noqa: E501
                                modality
                            ]
                            is None
                        ):
1620
1621
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
1622
1623
                                f"but UUID is not provided"
                            )
1624
                        else:
1625
1626
1627
1628
                            if (
                                len(multi_modal_uuids[modality]) <= i
                                or multi_modal_uuids[modality][i] is None
                            ):
1629
1630
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
1631
1632
                                    f"but UUID is not provided"
                                )
1633
            else:
1634
1635
1636
1637
1638
1639
1640
1641
1642
                if data is None and (
                    multi_modal_uuids is None
                    or modality not in multi_modal_uuids
                    or multi_modal_uuids[modality] is None
                ):
                    raise ValueError(
                        f"Multi-modal data for {modality} is None"
                        f" but UUID is not provided"
                    )
1643

1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
    def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        *,
        lora_request: Optional[LoRARequest],
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the Processor to process inputs for LLMEngine.

        Args:
            request_id: Identifier handed to the processor for this request.
            engine_prompt: The prompt to process.
            params: Sampling or pooling parameters; its
                `truncate_prompt_tokens` is validated against the model's
                `max_model_len`.
            lora_request: Optional LoRA request forwarded to the processor.
            priority: Request priority forwarded to the processor.

        Returns:
            A tuple of the processed engine request and the tokenization
            kwargs that were passed to the processor.
        """
        tokenization_kwargs: dict[str, Any] = {}
        # The kwargs dict is passed in by reference; presumably the helper
        # fills in truncation settings -- confirm in `_validate_truncation_size`.
        _validate_truncation_size(
            self.model_config.max_model_len,
            params.truncate_prompt_tokens,
            tokenization_kwargs,
        )

        processor = self._get_processor()
        engine_request = processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
        return engine_request, tokenization_kwargs

1672
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
    ) -> None:
        """Process one prompt and add it to the engine as a new request.

        The request id is the next value of `self.request_counter`, rendered
        as a string.

        Args:
            prompt: The prompt to add.
            params: Sampling or pooling parameters for this request.
            lora_request: Optional LoRA request for this request.
            priority: Priority passed to both the input processing step and
                the engine.
        """
        # Only the text component of the prompt is needed here; it is passed
        # to the engine alongside the processed request.
        prompt_text, _, _ = get_prompt_components(prompt)
        request_id = str(next(self.request_counter))

        engine_request, tokenization_kwargs = self._process_inputs(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
        )

        self.llm_engine.add_request(
            request_id,
            engine_request,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
            prompt_text=prompt_text,
        )
1699

1700
    def _run_engine(
1701
        self, *, use_tqdm: Union[bool, Callable[..., tqdm]] = True
1702
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
1703
1704
        # Initialize tqdm.
        if use_tqdm:
Zhuohan Li's avatar
Zhuohan Li committed
1705
            num_requests = self.llm_engine.get_num_unfinished_requests()
1706
1707
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
1708
1709
1710
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
1711
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
1712
            )
1713

Zhuohan Li's avatar
Zhuohan Li committed
1714
        # Run the engine.
1715
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
1716
1717
        total_in_toks = 0
        total_out_toks = 0
Zhuohan Li's avatar
Zhuohan Li committed
1718
1719
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
1720
            for output in step_outputs:
1721
                if output.finished:
1722
1723
                    outputs.append(output)
                    if use_tqdm:
1724
1725
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
1726
                            n = len(output.outputs)
1727
                            assert output.prompt_token_ids is not None
1728
                            total_in_toks += len(output.prompt_token_ids) * n
1729
1730
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(
1731
1732
1733
                                len(stp.token_ids) for stp in output.outputs
                            )
                            out_spd = total_out_toks / pbar.format_dict["elapsed"]
1734
1735
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
1736
1737
                                f"output: {out_spd:.2f} toks/s"
                            )
1738
                            pbar.update(n)
1739
1740
                        else:
                            pbar.update(1)
1741
1742
                        if pbar.n == num_requests:
                            pbar.refresh()
1743

1744
1745
        if use_tqdm:
            pbar.close()
1746
1747
1748
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
1749
        return sorted(outputs, key=lambda x: int(x.request_id))