llm.py 70.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
6
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar, deprecated
13

14
15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
    CompilationConfig,
22
    PoolerConfig,
23
24
25
    StructuredOutputsConfig,
    is_init_field,
)
26
from vllm.config.model import (
27
28
    ConvertOption,
    HfOverrides,
29
    ModelDType,
30
    RunnerOption,
31
    TokenizerMode,
32
)
33
from vllm.engine.arg_utils import EngineArgs
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
    apply_hf_chat_template,
    apply_mistral_chat_template,
    parse_chat_messages,
    resolve_chat_template_content_format,
)
from vllm.entrypoints.score_utils import (
    ScoreContentPartParam,
    ScoreMultiModalParam,
    _cosine_similarity,
    _validate_score_input_lens,
    compress_token_type_ids,
    get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
from vllm.inputs import (
    DataPrompt,
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
58
from vllm.inputs.parse import get_prompt_components
59
from vllm.logger import init_logger
60
from vllm.lora.request import LoRARequest
61
from vllm.model_executor.layers.quantization import QuantizationMethods
62
63
64
65
66
67
68
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
69
from vllm.pooling_params import PoolingParams
70
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
71
from vllm.tasks import PoolingTask
72
73
74
75
76
from vllm.transformers_utils.tokenizer import (
    AnyTokenizer,
    MistralTokenizer,
    get_cached_tokenizer,
)
yhu422's avatar
yhu422 committed
77
from vllm.usage.usage_lib import UsageContext
78
from vllm.utils import Counter, Device, as_iter, is_list_of
79
from vllm.v1.engine import EngineCoreRequest
80
from vllm.v1.engine.llm_engine import LLMEngine
81
from vllm.v1.sample.logits_processor import LogitsProcessor
82

83
84
85
if TYPE_CHECKING:
    # Only needed for type annotations; skipped at runtime.
    from vllm.v1.metrics.reader import Metric

# Logger shared by everything in this module.
logger = init_logger(__name__)

# Result type of callables executed on workers (collective_rpc / apply_model);
# falls back to Any via the TypeVar default.
_R = TypeVar("_R", default=Any)

90
91

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
92
93
94
95
96
97
98
99
100
101
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
102
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
103
104
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
105
106
107
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
108
109
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
110
111
112
113
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
114
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
116
117
118
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
119
120
121
122
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
123
        quantization: The method used to quantize the model weights. Currently,
124
            we support "awq", "gptq", and "fp8" (experimental).
125
126
127
128
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
129
130
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
131
132
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
133
134
135
136
137
138
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
139
140
141
142
143
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vLLM automatically infers the KV cache
            size from `gpu_memory_utilization`. Setting it explicitly gives
            finer-grained control over how much memory is used than
            `gpu_memory_utilization` does. Note that when
            `kv_cache_memory_bytes` is not None, `gpu_memory_utilization` is
            ignored.
147
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
153
154
155
156
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
157
158
159
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
160
161
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
162
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
165
166
167
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
168
169
170
171
172
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
173
174
175
176
177
        pooler_config: Initialize non-default pooling config for the pooling
            model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
        override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
            argument is deprecated and will be removed in v0.12.0 or v1.0.0,
            whichever is sooner.
178
179
180
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig` instance. If it is an integer, it is used as
            the level of compilation optimization. If it is a dictionary, it
            can specify the full compilation configuration.
181
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
182

183
184
    Note:
        This class is intended to be used for offline inference. For online
185
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
186
    """
187
188
189
190

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int | None = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        override_pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the offline engine.

        Normalizes the shorthand forms accepted for a few config arguments
        (an int or dict for `compilation_config`, a dict for
        `structured_outputs_config` and `kv_transfer_config`) into config
        objects. It then forwards everything to `EngineArgs`, builds the
        `LLMEngine`, and caches handles to the engine's model config,
        processor and I/O processor.

        Raises:
            ValueError: If `kv_transfer_config` is passed as a dict that
                `KVTransferConfig` rejects with a validation error.
        """
        # Stats logging is off unless the caller passes disable_log_stats.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class object (rather than a qualified-name string) is
        # serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        raw_kv_transfer = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer, dict):
            from vllm.config.kv_transfer import KVTransferConfig

            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_kv_transfer)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        def init_fields_only(config_cls: type, raw: dict[str, Any]) -> dict[str, Any]:
            # Keys that are not init fields of config_cls are dropped silently.
            return {k: v for k, v in raw.items() if is_init_field(config_cls, k)}

        hf_overrides = {} if hf_overrides is None else hf_overrides

        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            # A bare integer selects the compilation level.
            compilation_config_instance = CompilationConfig(level=compilation_config)
        elif isinstance(compilation_config, dict):
            compilation_config_instance = CompilationConfig(
                **init_fields_only(CompilationConfig, compilation_config)
            )
        else:
            compilation_config_instance = compilation_config

        if structured_outputs_config is None:
            structured_outputs_instance = StructuredOutputsConfig()
        elif isinstance(structured_outputs_config, dict):
            structured_outputs_instance = StructuredOutputsConfig(
                **init_fields_only(StructuredOutputsConfig, structured_outputs_config)
            )
        else:
            structured_outputs_instance = structured_outputs_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            override_pooler_config=override_pooler_config,
            structured_outputs_config=structured_outputs_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )
        log_non_default_args(engine_args)

        # LLMEngine here is the V1 engine (vllm.v1.engine.llm_engine).
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled lazily by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        self.supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", self.supported_tasks)

        self.model_config = self.llm_engine.model_config
        self.processor = self.llm_engine.processor
        self.io_processor = self.llm_engine.io_processor
341

342
343
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer currently held by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
344

345
    @deprecated("`set_tokenizer` is deprecated and will be removed in v0.13.")
    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Replace the engine's tokenizer.

        A tokenizer whose class name starts with "Cached" is assumed to be
        wrapped already and is installed as-is. Any other tokenizer is first
        wrapped with `get_cached_tokenizer`.
        """
        # Cached tokenizer classes are created dynamically, so the only
        # available test is the class name. A user-defined tokenizer whose
        # name happens to start with "Cached" will be misjudged.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        self.llm_engine.tokenizer = (
            tokenizer if looks_cached else get_cached_tokenizer(tokenizer)
        )
354

355
356
357
358
    def reset_mm_cache(self) -> None:
        """Clear the multi-modal caches of the input processor and the engine.

        The processor-side cache is cleared first, then the engine's.
        """
        processor = self.processor
        engine = self.llm_engine
        processor.clear_mm_cache()
        engine.reset_mm_cache()

359
    def get_default_sampling_params(self) -> SamplingParams:
        """Build the default `SamplingParams` for this model.

        The overrides returned by `model_config.get_diff_sampling_param()` are
        fetched on the first call and memoized in
        `self.default_sampling_params`. If they are non-empty they are passed
        to `SamplingParams.from_optional`; otherwise a plain
        `SamplingParams()` is returned. A fresh object is built on every call.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides

        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

366
367
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters
                (see `get_default_sampling_params`).
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request(s) to use for generation, if any. For
                multi-modal models with default modality-specific LoRAs
                configured, these may be resolved per prompt before the
                requests are added.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not "generate".
        """
        # Only generative models can produce completions.
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        # Attach any modality-specific default LoRAs to matching prompts.
        lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
433

434
    def _get_modality_specific_lora_reqs(
        self,
        prompts: PromptType | Sequence[PromptType],
        lora_request: list[LoRARequest] | LoRARequest | None,
    ):
        """Resolve the LoRA request to use for each prompt.

        Takes the configured default modality-specific LoRAs into account.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: The caller's explicit LoRA request(s). A single
                request (or None) applies to every prompt; a sequence is
                paired with the prompts one by one.

        Returns:
            `lora_request` unchanged when there is no LoRA config, no
            `default_mm_loras`, or the model is not multi-modal. Otherwise,
            a list with one (possibly None) LoRA request per prompt.
        """
        # The lora config is read from the engine's vllm config.
        lora_config = self.llm_engine.vllm_config.lora_config

        # Nothing to resolve without default mm loras on a multi-modal model.
        if (
            lora_config is None
            or not self.model_config.is_multimodal_model
            or lora_config.default_mm_loras is None
        ):
            return lora_request

        # A bare string is itself a Sequence. Wrap it, like any other single
        # prompt, so it counts as one prompt instead of one per character.
        if isinstance(prompts, str) or not isinstance(prompts, Sequence):
            prompts = [prompts]

        optional_loras = (
            lora_request
            if isinstance(lora_request, Sequence)
            else [lora_request] * len(prompts)
        )

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            )
            for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

470
471
472
    def _resolve_single_prompt_mm_lora(
        self,
        prompt: PromptType,
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
    ):
        """Choose the LoRA request for a single prompt.

        An explicit `lora_request` always wins. Otherwise, if the prompt's
        `multi_modal_data` uses exactly one modality with a registered
        default LoRA, a `LoRARequest` for that modality is built. Its ID is
        the 1-based position of the modality name in the sorted keys of
        `default_mm_loras`. In every other case `lora_request` is returned
        unchanged.
        """
        # Only dict prompts can carry multi-modal data.
        if not default_mm_loras or not isinstance(prompt, dict):
            return lora_request
        mm_data = prompt.get("multi_modal_data") or {}
        if not mm_data:
            return lora_request

        matching = set(mm_data.keys()).intersection(default_mm_loras)  # type: ignore
        if not matching:
            return lora_request
        if len(matching) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped",
                matching,
            )
            return lora_request

        (modality_name,) = matching
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        if not lora_request:
            return LoRARequest(
                modality_name,
                modality_lora_id,
                modality_lora_path,
            )

        # The explicit request always wins; warn only if its ID disagrees.
        if lora_request.lora_int_id != modality_lora_id:
            logger.warning(
                "A modality with a registered lora and a lora_request "
                "with a different ID were provided; falling back to the "
                "lora_request as we only apply one LoRARequest per prompt"
            )
        return lora_request

523
524
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Invoke the same RPC on every worker and gather the results.

        Args:
            method: Either the name of a worker method to call, or a callable
                that is serialized and shipped to all workers. A callable
                must accept an additional `self` argument (the worker object)
                besides whatever is passed through `args` and `kwargs`.
            timeout: Maximum number of seconds to wait. A
                [`TimeoutError`][] is raised when it elapses; `None` waits
                indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list holding one result per worker.

        Note:
            Prefer this API for control messages only, and set up a separate
            data-plane channel for passing data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
554
555

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Call `func` on the model inside every worker and collect the results.

        !!! warning
            Returning large arrays or tensors from `func` adds data-transfer
            overhead. If you must return them, move them to CPU first so they
            do not take up additional VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
567

568
569
    def _get_beam_search_lora_requests(
        self,
        lora_request: list[LoRARequest] | LoRARequest | None,
        prompts: list[TokensPrompt | TextPrompt],
    ) -> list[LoRARequest | None]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: None or a single `LoRARequest` (applied to every
                prompt), or a sequence with exactly one entry per prompt.
            prompts: The beam-search prompts.

        Returns:
            A list with one (possibly None) LoRA request per prompt.

        Raises:
            ValueError: If a sequence of requests does not match the number
                of prompts.
            TypeError: If `lora_request` is of any other type.
        """
        if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts"
            )

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        # A correctly sized per-prompt sequence used to fall through to the
        # TypeError below; return it as a list instead.
        if isinstance(lora_request, Sequence):
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

584
585
    def beam_search(
        self,
586
        prompts: list[TokensPrompt | TextPrompt],
587
        params: BeamSearchParams,
588
        lora_request: list[LoRARequest] | LoRARequest | None = None,
589
        use_tqdm: bool = False,
590
        concurrency_limit: int | None = None,
591
    ) -> list[BeamSearchOutput]:
592
593
594
595
596
597
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
598
            params: The beam search parameters.
599
            lora_request: LoRA request to use for generation, if any.
600
            use_tqdm: Whether to use tqdm to display the progress bar.
601
602
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
603
        """
604
605
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
606
607
608
609
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
610
611
        length_penalty = params.length_penalty

612
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
613

614
615
616
617
618
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
619

620
621
622
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
623
624
                "Disabling progress bar."
            )
625
626
627
628
629
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

630
631
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
632
633
634
635
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
636
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
637
            return TokensPrompt(**token_prompt_kwargs)
638

639
640
641
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
642
643
644
        beam_search_params = SamplingParams(
            logprobs=2 * beam_width, max_tokens=1, temperature=temperature
        )
645
        instances: list[BeamSearchInstance] = []
646

647
        for lora_req, prompt in zip(lora_requests, prompts):
648
649
650
651
652
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
653
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
654

655
656
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
657
658
659
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
660

661
            instances.append(
662
663
664
665
666
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
667
668
                ),
            )
669

670
        for prompt_start in range(0, len(prompts), concurrency_limit):
671
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
672
673
674

            token_iter = range(max_tokens)
            if use_tqdm:
675
676
677
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
678
679
680
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
681
682
                    "reflect instance-level progress."
                )
683
684
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
685
686
                    sum((instance.beams for instance in instances_batch), [])
                )
687
688
                pos = [0] + list(
                    itertools.accumulate(
689
690
691
                        len(instance.beams) for instance in instances_batch
                    )
                )
692
                instance_start_and_end: list[tuple[int, int]] = list(
693
694
                    zip(pos[:-1], pos[1:])
                )
695
696
697
698
699
700

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
701
702
703
704
705
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
706
707
708

                # only runs for one step
                # we don't need to use tqdm here
709
710
711
712
713
714
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
715

716
717
718
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
719
720
721
722
723
724
725
726
727
728
729
730
731
732
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
733
                                    logprobs=current_beam.logprobs + [logprobs],
734
                                    lora_request=current_beam.lora_request,
735
736
737
738
739
740
741
742
743
744
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
745
746
747
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
748
749
750
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
751
                    instance.beams = sorted_beams[:beam_width]
752
753
754
755

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
756
757
758
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
759
760
761
762
763
764
765
766
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

767
    def preprocess_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[TokensPrompt]:
        """
        Render one or more chat conversations into tokenized prompts.

        Each conversation is parsed (collecting any multi-modal data), passed
        through the chat template and tokenized. The resulting prompts can be
        passed as input to the other `LLM` methods.

        Refer to `chat` for a complete description of the arguments.

        Returns:
            One `TokensPrompt` per conversation, in input order. Each holds
            the token ids of the templated prompt plus, when available, the
            parsed multi-modal data, the multi-modal UUIDs and the given
            `mm_processor_kwargs`.
        """
        # If `messages` is a list of lists (per `is_list_of`) it is a batch of
        # conversations; otherwise it is a single conversation.
        all_conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list)
            else [cast(list[ChatCompletionMessageParam], messages)]
        )

        tokenizer = self.get_tokenizer()
        model_config = self.model_config
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Entries from `chat_template_kwargs` override the explicit arguments.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        engine_prompts: list[TokensPrompt] = []
        for conversation_msgs in all_conversations:
            # NOTE: chat message parsing has no implementation for
            # mm_processor_kwargs, so they are attached to the prompt below.
            parsed_conversation, mm_data, mm_uuids = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=parsed_conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # Chat templates already contain the special tokens, so the
                # tokenizer must not add them again.
                token_ids = tokenizer.encode(rendered, add_special_tokens=False)

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_uuids is not None:
                engine_prompt["multi_modal_uuids"] = mm_uuids
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            engine_prompts.append(engine_prompt)

        return engine_prompts

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The conversation is rendered with the chat template and tokenized via
        `preprocess_chat`. The resulting prompts are then handed to the
        [generate][vllm.LLM.generate] method to produce the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, the chat template is asked to add
                a generation prompt (forwarded as `add_generation_prompt`).
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Tool definitions forwarded to the chat template; they are
                also used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. These override the explicit arguments above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        engine_prompts = self.preprocess_chat(
            messages=messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        return self.generate(
            engine_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

948
949
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = "encode",
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict with
                a `"data"` key (`DataPrompt`) is instead parsed and
                pre-processed by the installed IOProcessor plugin.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Each params object is
                verified against `pooling_task`, and `truncate_prompt_tokens`
                (if given) is written onto it in place.
            truncate_prompt_tokens: If not None, copied onto every pooling
                params object as `truncate_prompt_tokens` (kept for backwards
                compatibility).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run; defaults to `"encode"`.
                If explicitly set to None, the task is inferred. It becomes
                `"encode"` when that is the only supported task. Otherwise it
                becomes `"embed"` if supported, else `"encode"`, and a
                one-time warning is logged.
            tokenization_kwargs: NOTE(review): this argument is accepted but
                is not read anywhere in this method, so it currently has no
                effect here. Confirm whether it should be forwarded.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For an IOProcessor prompt, a single-element list wrapping the
            plugin's post-processed output is returned instead.

        Raises:
            ValueError: If the model is not a pooling model, if `pooling_task`
                is not supported by the model, or if a `DataPrompt` is given
                but no IOProcessor plugin is installed.
        """
        if self.supported_tasks == ["encode"] and pooling_task is None:
            pooling_task = "encode"

        if pooling_task is None:
            pooling_task = "embed" if "embed" in self.supported_tasks else "encode"

            logger.warning_once(
                "`LLM.encode` is currently using `pooling_task = %s`.\n"
                "Please use one of the more specific methods or set the "
                "task directly when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="reward"`\n'
                "  - For similarity scores, use `LLM.score(...)`.",
                pooling_task,
            )

        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if pooling_task not in self.supported_tasks:
            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if io_processor_prompt:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(
                model_output=model_outputs
            )

            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    prompt_token_ids=[],
                    finished=True,
                )
            ]
        else:
            return model_outputs
1080

1081
1082
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This is a thin wrapper around `encode` with `pooling_task="embed"`.
        Prompts are batched automatically, considering the memory constraint;
        for the best performance, pass all of your prompts in a single list.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Optional truncation limit. It is forwarded
                to `encode`, which copies it onto every pooling params object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        supported = self.supported_tasks
        if "embed" not in supported:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        pooled = self.encode(
            prompts,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Optional truncation limit. It is forwarded
                to `encode`, which copies it onto every pooling params object.
                Defaults to None (no override), matching `embed` and `reward`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification outputs (class logits) in the same order as the
            input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1176
1177
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        This delegates to `encode` with `pooling_task="encode"`.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Optional truncation limit. It is forwarded
                to `encode`, which copies it onto every pooling params object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        encode_kwargs = dict(
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        return self.encode(prompts, pooling_task="encode", **encode_kwargs)

1214
1215
1216
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str | TextPrompt | TokensPrompt],
        text_2: list[str | TextPrompt | TokensPrompt],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by cosine similarity of their embeddings.

        Both sides are embedded in a single `encode` call (task "embed") and
        the results are split back at `len(text_1)`. If the first side yields
        exactly one embedding, it is repeated to pair with every embedding of
        the second side. Scores are computed with `_cosine_similarity`.
        """
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )

        boundary = len(text_1)
        left: list[PoolingRequestOutput] = embeddings[:boundary]
        right: list[PoolingRequestOutput] = embeddings[boundary:]

        # 1 -> N scoring: broadcast the single left embedding.
        if len(left) == 1:
            left = left * len(right)

        similarity = _cosine_similarity(
            tokenizer=tokenizer, embed_1=left, embed_2=right
        )
        validated = self.engine_class.validate_outputs(
            similarity, PoolingRequestOutput
        )
        return list(map(ScoringRequestOutput.from_base, validated))

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: list[str] | list[ScoreContentPartParam],
        data_2: list[str] | list[ScoreContentPartParam],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score (data_1[i], data_2[i]) pairs with a cross-encoder model.

        A single-element `data_1` is broadcast against every item of `data_2`.
        Each pair becomes one engine prompt built by `get_score_prompt`. When a
        prompt carries `token_type_ids`, they are compressed and attached to a
        per-request clone of the pooling params.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        model_config = self.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        # 1 -> N scoring: repeat the single query for every document.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        pooling_params.verify("score", model_config)

        # May be filled in by the truncation validation below.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(
            model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
        )

        engine_prompts: list[PromptType] = []
        request_params: list[PoolingParams] = []
        for query_item, doc_item in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=query_item,
                data_2=doc_item,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            token_type_ids = engine_prompt.pop("token_type_ids", None)
            params_for_pair = pooling_params
            if token_type_ids:
                # Only prompts with token type ids need their own params copy.
                params_for_pair = pooling_params.clone()
                params_for_pair.extra_kwargs = {
                    "compressed_token_type_ids": compress_token_type_ids(
                        token_type_ids
                    )
                }
            request_params.append(params_for_pair)
            engine_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=engine_prompts,
            params=request_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        raw_outputs = self._run_engine(use_tqdm=use_tqdm)
        validated = self.engine_class.validate_outputs(
            raw_outputs, PoolingRequestOutput
        )
        return list(map(ScoringRequestOutput.from_base, validated))

1311
1312
    def score(
        self,
1313
1314
        data_1: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
1315
        /,
1316
        *,
1317
1318
1319
1320
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
1321
    ) -> list[ScoringRequestOutput]:
1322
1323
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.
1324

1325
        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
1326
1327
        In the `1 - N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
1328
        The input pairs are used to build a list of prompts for the
1329
1330
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
1331
1332
1333
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
1334
        appropriate multi-modal models. For multi-modal inputs, ensure the
1335
        prompt structure matches the model's expected input format.
1336
1337

        Args:
1338
1339
1340
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
1341
                the `data_2` list.
1342
            data_2: The data to pair with the query to form the input to
1343
                the LLM. Can be text or multi-modal data. See [PromptType]
1344
                [vllm.inputs.PromptType] for more details about the format of
1345
                each prompt.
1346
1347
1348
1349
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
1350
            lora_request: LoRA request to use for generation, if any.
1351
1352
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
1353
        Returns:
1354
            A list of `ScoringRequestOutput` objects containing the
1355
1356
            generated scores in the same order as the input prompts.
        """
1357
        model_config = self.model_config
1358
        runner_type = model_config.runner_type
1359
        if runner_type != "pooling":
1360
1361
1362
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
1363
1364
                "pooling model."
            )
1365

1366
1367
        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
1368
1369
1370
1371
1372
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )
1373

1374
1375
1376
1377
        if (
            model_config.is_cross_encoder
            and getattr(model_config.hf_config, "num_labels", 0) != 1
        ):
1378
            raise ValueError("Score API is only enabled for num_labels == 1.")
1379
1380
1381
1382

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
1383
        tokenizer = self.get_tokenizer()
1384

1385
        if not model_config.is_multimodal_model:
1386

1387
            def check_data_type(
1388
1389
1390
                data: SingletonPrompt
                | Sequence[SingletonPrompt]
                | ScoreMultiModalParam,
1391
            ):
1392
                if isinstance(data, dict) and "content" in data:
1393
1394
1395
1396
                    raise ValueError(
                        "ScoreMultiModalParam is not supported "
                        f"for {model_config.architecture}"
                    )
1397
1398
1399
1400
1401
1402
1403

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt):
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
1404
1405
1406
                        raise ValueError(
                            "Multi-modal prompt is not supported for scoring"
                        )
1407
1408
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
1409
1410
                            cast(TokensPrompt, prompt)["prompt_token_ids"]
                        )
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                assert type(prompt) is str
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]
1439

1440
        if model_config.is_cross_encoder:
1441
1442
1443
1444
1445
1446
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
1447
                pooling_params,
1448
1449
                lora_request,
            )
1450
        else:
1451
1452
            return self._embedding_score(
                tokenizer,
1453
1454
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
1455
1456
                truncate_prompt_tokens,
                use_tqdm,
1457
                pooling_params,
1458
1459
                lora_request,
            )
1460

1461
1462
1463
1464
1465
1466
    def start_profile(self) -> None:
        """Ask the wrapped engine to begin collecting a profile."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the wrapped engine to stop collecting a profile."""
        engine = self.llm_engine
        engine.stop_profile()

1467
    def reset_prefix_cache(self, device: Device | None = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, forwarded unchanged to the
                engine's ``reset_prefix_cache``.

        Returns:
            The engine's return value (annotated as ``bool``; presumably
            whether the reset succeeded — confirm against the engine).
        """
        engine = self.llm_engine
        outcome = engine.reset_prefix_cache(device)
        return outcome
1469

1470
1471
1472
1473
1474
1475
    def sleep(self, level: int = 1):
        """
        Put the engine into sleep mode.

        While asleep, the engine must not serve requests. Callers must ensure
        nothing is in flight from the moment this is called until `wake_up`
        is invoked.

        Args:
            level: How deeply to sleep.
                Level 1 offloads the model weights and discards the KV cache
                contents. The weights are backed up in CPU memory, so enough
                CPU memory must be available to hold them. Use level 1 when
                the engine will be woken up to run the same model again.
                Level 2 discards both the model weights and the KV cache
                contents. Use level 2 when the engine will be woken up to run
                a different or updated model whose previous weights are not
                needed. It reduces CPU memory pressure.
        """
        # Both levels forget the KV cache contents, so any cached prefixes
        # would be stale after sleeping; clear them first.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1492
    def wake_up(self, tags: list[str] | None = None):
1493
        """
1494
1495
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1496

1497
        Args:
1498
1499
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1500
                `("weights", "kv_cache")`. If None, all memory is reallocated.
1501
                wake_up should be called with all tags (or None) before the
1502
1503
1504
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1505

1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects, as returned by the underlying
            engine, capturing the current state of the aggregated
            Prometheus metrics.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1518
1519
    def _validate_and_add_requests(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        params: SamplingParams
        | Sequence[SamplingParams]
        | PoolingParams
        | Sequence[PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None,
        priority: list[int] | None = None,
    ) -> None:
        """Validate a batch of prompts and add them to the engine.

        All length checks run before any request is added to the engine.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts. A single prompt is wrapped into a one-element list.
            params: Either one params object shared by every prompt, or a
                sequence with one entry per prompt. ``SamplingParams``
                objects are mutated in place to ``FINAL_ONLY`` output kind.
            use_tqdm: If truthy, wraps prompt submission in a progress bar;
                a callable is used as the progress-bar factory instead of
                ``tqdm``.
            lora_request: One LoRA request shared by every prompt, ``None``,
                or a sequence with one entry per prompt.
            priority: Optional per-prompt priorities. When ``None`` or
                empty, every request gets priority 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or a non-empty
                ``priority`` does not have one entry per prompt, or if a
                multi-modal prompt has ``None`` data without a UUID.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]  # type: ignore[list-item]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params must be the same.")
        if isinstance(lora_request, Sequence) and len(lora_request) != num_requests:
            raise ValueError(
                "The lengths of prompts and lora_request must be the same."
            )
        # An empty/None priority keeps the "priority 0 for all" default used
        # below. Any other mismatch would otherwise surface as an IndexError
        # midway through the loop, after some requests were already added.
        if priority and len(priority) != num_requests:
            raise ValueError(
                "The lengths of prompts and priority must be the same."
            )

        for sp in params if isinstance(params, Sequence) else (params,):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            if isinstance(prompt, dict):
                self._validate_mm_data_and_uuids(
                    prompt.get("multi_modal_data"), prompt.get("multi_modal_uuids")
                )

            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i]
                if isinstance(lora_request, Sequence)
                else lora_request,
                priority=priority[i] if priority else 0,
            )
1567

1568
    def _validate_mm_data_and_uuids(
1569
        self,
1570
1571
        multi_modal_data: Any | None,  # MultiModalDataDict
        multi_modal_uuids: Any | None,  # MultiModalUUIDDict
1572
1573
1574
    ):
        """
        Validate that if any multi-modal data is skipped (i.e. None),
1575
        then its corresponding UUID must be set.
1576
1577
1578
1579
1580
1581
1582
1583
        """
        if multi_modal_data is None:
            return

        for modality, data in multi_modal_data.items():
            if isinstance(data, list):
                for i, d in enumerate(data):
                    if d is None:
1584
1585
1586
1587
1588
1589
1590
1591
                        if (
                            multi_modal_uuids is None
                            or modality not in multi_modal_uuids
                            or multi_modal_uuids[  # noqa: E501
                                modality
                            ]
                            is None
                        ):
1592
1593
                            raise ValueError(
                                f"Multi-modal data for {modality} is None "
1594
1595
                                f"but UUID is not provided"
                            )
1596
                        else:
1597
1598
1599
1600
                            if (
                                len(multi_modal_uuids[modality]) <= i
                                or multi_modal_uuids[modality][i] is None
                            ):
1601
1602
                                raise ValueError(
                                    f"Multi-modal data for {modality} is None "
1603
1604
                                    f"but UUID is not provided"
                                )
1605
            else:
1606
1607
1608
1609
1610
1611
1612
1613
1614
                if data is None and (
                    multi_modal_uuids is None
                    or modality not in multi_modal_uuids
                    or multi_modal_uuids[modality] is None
                ):
                    raise ValueError(
                        f"Multi-modal data for {modality} is None"
                        f" but UUID is not provided"
                    )
1615

1616
1617
1618
1619
    def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        priority: int,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Turn a prompt into an engine request via ``self.processor``.

        Args:
            request_id: ID to attach to the resulting engine request.
            engine_prompt: The prompt to process.
            params: Sampling or pooling parameters; their
                ``truncate_prompt_tokens`` is checked against the model's
                ``max_model_len``.
            lora_request: Optional LoRA request forwarded to the processor.
            priority: Request priority forwarded to the processor.

        Returns:
            A ``(engine_request, tokenization_kwargs)`` pair, where the
            kwargs dict is the same one handed to the processor.
        """
        tok_kwargs: dict[str, Any] = {}
        # The (initially empty) dict is handed to the helper, which presumably
        # populates it in place from the truncation setting.
        _validate_truncation_size(
            self.model_config.max_model_len,
            params.truncate_prompt_tokens,
            tok_kwargs,
        )

        processed = self.processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tok_kwargs,
            priority=priority,
        )
        return processed, tok_kwargs

1643
    def _add_request(
        self,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
    ) -> None:
        """Process one prompt and submit it to the engine.

        The request ID is the next value of ``self.request_counter``,
        converted to ``str``. The prompt's text component is passed to the
        engine as ``prompt_text``.
        """
        text_of_prompt, _, _ = get_prompt_components(prompt)
        req_id = str(next(self.request_counter))

        core_request, tok_kwargs = self._process_inputs(
            req_id,
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
        )

        engine = self.llm_engine
        engine.add_request(
            req_id,
            core_request,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tok_kwargs,
            priority=priority,
            prompt_text=text_of_prompt,
        )
1670

1671
    def _run_engine(
        self, *, use_tqdm: bool | Callable[..., tqdm] = True
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until every pending request has finished.

        Args:
            use_tqdm: If truthy, show a progress bar whose total is the number
                of unfinished requests at call time. A callable is used as
                the progress-bar factory instead of ``tqdm``. For
                ``RequestOutput`` results, the bar's postfix shows estimated
                input/output token throughput.

        Returns:
            The finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue

                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs
                        )
                        # A bar can report 0 elapsed time (e.g. one created
                        # with ``disable=True`` via a custom factory); avoid
                        # a ZeroDivisionError in that case.
                        elapsed = pbar.format_dict["elapsed"]
                        in_spd = total_in_toks / elapsed if elapsed else 0.0
                        out_spd = total_out_toks / elapsed if elapsed else 0.0
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s"
                        )
                        pbar.update(n)
                    else:
                        pbar.update(1)

                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Close the bar even if a step raises, so it doesn't linger.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))