llm.py 71.1 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
6
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar
13

14
15
16
17
18
19
20
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
21
    AttentionConfig,
22
    CompilationConfig,
23
    PoolerConfig,
24
    ProfilerConfig,
25
26
27
    StructuredOutputsConfig,
    is_init_field,
)
28
from vllm.config.compilation import CompilationMode
29
from vllm.config.model import (
30
31
    ConvertOption,
    HfOverrides,
32
    ModelDType,
33
    RunnerOption,
34
    TokenizerMode,
35
)
36
from vllm.engine.arg_utils import EngineArgs
37
38
39
40
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
41
from vllm.entrypoints.pooling.score.utils import (
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
    ScoreContentPartParam,
    ScoreMultiModalParam,
    _cosine_similarity,
    _validate_score_input_lens,
    compress_token_type_ids,
    get_score_prompt,
)
from vllm.entrypoints.utils import _validate_truncation_size, log_non_default_args
from vllm.inputs import (
    DataPrompt,
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
57
from vllm.inputs.parse import get_prompt_components
58
from vllm.logger import init_logger
59
from vllm.lora.request import LoRARequest
60
from vllm.model_executor.layers.quantization import QuantizationMethods
61
62
63
64
65
66
67
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
68
from vllm.platforms import current_platform
69
from vllm.pooling_params import PoolingParams
70
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
71
from vllm.tasks import PoolingTask
72
73
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
yhu422's avatar
yhu422 committed
74
from vllm.usage.usage_lib import UsageContext
75
from vllm.utils.collection_utils import as_iter, is_list_of
76
from vllm.utils.counter import Counter
77
from vllm.v1.engine import EngineCoreRequest
78
from vllm.v1.engine.llm_engine import LLMEngine
79
from vllm.v1.sample.logits_processor import LogitsProcessor
80

81
82
83
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

84
85
# Module-level logger named after this module.
logger = init_logger(__name__)

86
87
# Generic result type used by `collective_rpc`, `apply_model`, and the
# config-building helper in `LLM.__init__`; defaults to `Any` (PEP 696).
_R = TypeVar("_R", default=Any)

88
89

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
90
91
92
93
94
95
96
97
98
99
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
100
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
101
102
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
103
104
105
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
106
107
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
108
109
110
111
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
114
115
116
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
117
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
118
119
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
120
        quantization: The method used to quantize the model weights. Currently,
121
            we support "awq", "gptq", and "fp8" (experimental).
122
123
124
125
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
126
127
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
128
129
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
130
131
132
133
134
135
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
136
137
138
139
140
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vLLM automatically infers the KV cache
            size from gpu_memory_utilization. Users who want to set the KV
            cache size manually can use kv_cache_memory_bytes, which gives
            finer-grained control over memory usage than
            gpu_memory_utilization. When kv_cache_memory_bytes is not None,
            gpu_memory_utilization is ignored.
144
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
150
151
152
153
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
154
155
156
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
157
        enable_return_routed_experts: Whether to return routed experts.
158
159
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
160
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
163
164
165
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
166
167
168
169
170
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
171
        pooler_config: Initialize non-default pooling config for the pooling
172
            model. e.g. `PoolerConfig(seq_pooling_type="MEAN", normalize=False)`.
173
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig` instance. If it is an integer, it is used as
            the mode of compilation optimization. If it is a dictionary, it
            can specify the full compilation configuration.
176
177
178
179
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
180
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
181

182
183
    Note:
        This class is intended to be used for offline inference. For online
184
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
185
    """
186
187
188
189

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes the config-style arguments into config objects, builds an
        `EngineArgs` from all arguments (see the class docstring for their
        meaning), creates the underlying `LLMEngine`, and caches handles to
        the engine's model config and input/IO processors.

        Raises:
            ValueError: If `kv_transfer_config` is given as a dict that fails
                validation, or if `data_parallel_size > 1` is requested for
                single-process usage (allowed only with the
                `external_launcher` distributed backend or on TPU).
        """
        # Stats logging is disabled unless the caller explicitly asks for it.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # If worker_cls is a class object rather than a qualified string
            # name, serialize it with cloudpickle to avoid pickling issues.
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        raw_kv_transfer_config = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer_config, dict):
            from vllm.config.kv_transfer import KVTransferConfig

            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_kv_transfer_config
                )
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer_config,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance.

            `None` yields a default-constructed `cls`; a dict is passed as
            keyword arguments, keeping only keys that are init fields of
            `cls` (other keys are dropped); any other value is returned as-is.
            """
            if value is None:
                return cls()
            if isinstance(value, dict):
                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
            return value

        if isinstance(compilation_config, int):
            # An integer selects only the compilation mode.
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data-parallel usage, which may hang. It is
        # only permitted with the external launcher backend or on TPU.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        self.model_config = self.llm_engine.model_config
        self.input_processor = self.llm_engine.input_processor
        self.io_processor = self.llm_engine.io_processor

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

353
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer owned by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
355

356
    def reset_mm_cache(self) -> None:
        """Clear the multi-modal caches.

        The input processor's multi-modal cache is cleared first, then the
        engine is asked to reset its own multi-modal cache.
        """
        processor = self.input_processor
        processor.clear_mm_cache()
        engine = self.llm_engine
        engine.reset_mm_cache()

360
    def get_default_sampling_params(self) -> SamplingParams:
        """Build a fresh `SamplingParams` from the model's defaults.

        The model config's diff sampling parameters are fetched once and
        cached on `self.default_sampling_params`. Every call then returns a
        new `SamplingParams` built from them, or a plain `SamplingParams()`
        when they are empty.
        """
        cached = self.default_sampling_params
        if cached is None:
            cached = self.model_config.get_diff_sampling_param()
            self.default_sampling_params = cached
        if not cached:
            return SamplingParams()
        return SamplingParams.from_optional(**cached)

367
368
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters
                (see `get_default_sampling_params`).
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any: either a
                single request or a list with one entry per prompt. For
                multi-modal models with default modality-specific LoRAs
                configured, a matching default LoRA may be attached to prompts
                that have no explicit LoRA request.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the prompt
                at the same index.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running as a generative model.
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            # Fall back to the model's default sampling parameters.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
437

438
    def _get_modality_specific_lora_reqs(
        self,
        prompts: PromptType | Sequence[PromptType],
        lora_request: list[LoRARequest] | LoRARequest | None,
    ) -> list[LoRARequest | None] | LoRARequest | None:
        """Attach default modality-specific LoRAs to multi-modal prompts.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: A single LoRA request applied to every prompt, a
                sequence with one entry per prompt, or None.

        Returns:
            `lora_request` unchanged when there is no LoRA config, no default
            multi-modal LoRAs, or the model is not multi-modal. Otherwise, a
            list with one (possibly None) LoRA request per prompt.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from the number of prompts.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (
            lora_config is None
            or not self.model_config.is_multimodal_model
            or lora_config.default_mm_loras is None
        ):
            return lora_request

        # Normalize a single prompt (including a bare string) to a list.
        if not isinstance(prompts, Sequence) or isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            # zip() below would silently truncate a mismatched list, so
            # validate the length explicitly.
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(prompts)

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            )
            for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

474
475
476
    def _resolve_single_prompt_mm_lora(
        self,
        prompt: PromptType,
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
    ) -> LoRARequest | None:
        """Choose the LoRA request for one prompt.

        If the prompt is a dict whose `multi_modal_data` uses exactly one
        modality that has a default LoRA registered in `default_mm_loras`,
        and no explicit `lora_request` was given, a `LoRARequest` for that
        default LoRA is returned. In every other case `lora_request` is
        returned unchanged; warnings are logged when several matching
        modalities are present or an explicit request's ID differs from the
        default LoRA's ID.

        Args:
            prompt: The prompt to inspect.
            lora_request: The explicitly provided LoRA request, if any.
            default_mm_loras: Mapping of modality name to LoRA path.
        """
        if (
            not default_mm_loras
            or not isinstance(prompt, dict)
            or not (mm_data := prompt.get("multi_modal_data") or {})
        ):
            return lora_request

        intersection = set(
            mm_data.keys()  # type: ignore
        ).intersection(default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped",
                intersection,
            )
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; warn when its ID
        # differs from the ID of the default modality LoRA.
        if lora_request is not None:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt"
                )
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

527
528
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Invoke a method on every worker and collect the results.

        Args:
            method: Either the name of the worker method to run, or a
                callable that is serialized and shipped to every worker.

                A callable receives the worker object as an extra leading
                `self` argument, followed by the contents of `args` and
                `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait. A
                [`TimeoutError`][] is raised when it expires; `None` waits
                without limit.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, collected into a list.

        Note:
            Prefer using this API for control messages only; set up a
            separate data-plane communication channel for passing data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
558
559

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model held by each worker and collect one result
        per worker.

        !!! warning
            Returning large arrays or tensors from `func` adds data-transfer
            overhead, so avoid it. If you must return them, move them to CPU
            first so they do not consume additional VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
571

572
573
    def _get_beam_search_lora_requests(
        self,
        lora_request: list[LoRARequest] | LoRARequest | None,
        prompts: list[TokensPrompt | TextPrompt],
    ) -> list[LoRARequest | None]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: None or a single `LoRARequest` (broadcast to every
                prompt), or a sequence with one entry per prompt.
            prompts: The beam-search prompts.

        Returns:
            A list with exactly one (possibly None) LoRA request per prompt.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from the number of prompts.
            TypeError: If `lora_request` has an unsupported type.
        """
        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            # Previously a correctly sized list fell through to the TypeError
            # below; return it as the per-prompt list instead.
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

588
589
    def beam_search(
        self,
590
        prompts: list[TokensPrompt | TextPrompt],
591
        params: BeamSearchParams,
592
        lora_request: list[LoRARequest] | LoRARequest | None = None,
593
        use_tqdm: bool = False,
594
        concurrency_limit: int | None = None,
595
    ) -> list[BeamSearchOutput]:
596
597
598
599
600
601
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
602
            params: The beam search parameters.
603
            lora_request: LoRA request to use for generation, if any.
604
            use_tqdm: Whether to use tqdm to display the progress bar.
605
606
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
607
        """
608
609
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
610
611
612
613
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
614
615
        length_penalty = params.length_penalty

616
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
617

618
619
620
621
622
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
623

624
625
626
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
627
628
                "Disabling progress bar."
            )
629
630
631
632
633
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

634
635
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
636
637
638
639
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
640
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
641
            return TokensPrompt(**token_prompt_kwargs)
642

643
644
645
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
646
        beam_search_params = SamplingParams(
647
648
649
650
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
651
        )
652
        instances: list[BeamSearchInstance] = []
653

654
        for lora_req, prompt in zip(lora_requests, prompts):
655
656
657
658
659
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
660
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
661

662
663
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
664
665
666
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
667

668
            instances.append(
669
670
671
672
673
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
674
675
                ),
            )
676

677
        for prompt_start in range(0, len(prompts), concurrency_limit):
678
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
679
680
681

            token_iter = range(max_tokens)
            if use_tqdm:
682
683
684
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
685
686
687
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
688
689
                    "reflect instance-level progress."
                )
690
691
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
692
693
                    sum((instance.beams for instance in instances_batch), [])
                )
694
695
                pos = [0] + list(
                    itertools.accumulate(
696
697
698
                        len(instance.beams) for instance in instances_batch
                    )
                )
699
                instance_start_and_end: list[tuple[int, int]] = list(
700
701
                    zip(pos[:-1], pos[1:])
                )
702
703
704
705
706
707

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
708
709
710
711
712
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
713
714
715

                # only runs for one step
                # we don't need to use tqdm here
716
717
718
719
720
721
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
722

723
724
725
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
726
727
728
729
730
731
732
733
734
735
736
737
738
739
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
740
                                    logprobs=current_beam.logprobs + [logprobs],
741
                                    lora_request=current_beam.lora_request,
742
743
744
745
746
747
748
749
750
751
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
752
753
754
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
755
756
757
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
758
                    instance.beams = sorted_beams[:beam_width]
759
760
761
762

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
763
764
765
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
766
767
768
769
770
771
772
773
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def preprocess_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[TextPrompt | TokensPrompt]:
        """
        Generate prompts for one or more chat conversations. The pre-processed
        prompts can then be used as input for the other LLM methods.

        Refer to `chat` for a complete description of the arguments.

        Returns:
            A list with one `TextPrompt` or `TokensPrompt` per conversation,
            containing the prompt after chat template interpolation and the
            pre-processed multi-modal inputs. If `mm_processor_kwargs` is
            given, each prompt receives its own shallow copy of it.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]], messages)
        else:
            # messages is list[...]
            list_of_messages = [cast(list[ChatCompletionMessageParam], messages)]

        renderer = self.llm_engine.renderer

        # The explicit arguments form the base; entries of the user-supplied
        # `chat_template_kwargs` are merged last and therefore take precedence.
        render_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        prompts: list[TextPrompt | TokensPrompt] = []
        for msgs in list_of_messages:
            # NOTE: renderer.render_messages() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            _, prompt = renderer.render_messages(
                msgs,
                chat_template_content_format=chat_template_content_format,
                **render_kwargs,
            )

            if mm_processor_kwargs is not None:
                # Copy per prompt so prompts never alias each other or the
                # caller's dict.
                prompt["mm_processor_kwargs"] = dict(mm_processor_kwargs)

            prompts.append(prompt)

        return prompts

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a prompt by `preprocess_chat`
        (which renders it through the chat template) and then passed to the
        [generate][vllm.LLM.generate] method to generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`
                - "auto" (the default) leaves the choice to the renderer.

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional list of tool definitions. They are forwarded to
                the chat template under the `tools` key.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. Entries here take precedence over the explicit
                `chat_template`, `add_generation_prompt`,
                `continue_final_message` and `tools` arguments.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """

        prompts = self.preprocess_chat(
            messages=messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict with
                a `"data"` key is instead treated as a request for the
                installed IOProcessor plugin.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Kept for backwards compatibility. If set,
                it overwrites `truncate_prompt_tokens` on every pooling
                parameter object used (including caller-supplied ones).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run. Required; it must be one of
                the tasks supported by the model.
            tokenization_kwargs: overrides tokenization_kwargs set in
                pooling_params

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For an IOProcessor request, the list holds a single output that
            wraps the plugin's post-processed result.

        Raises:
            ValueError: If `pooling_task` is missing or unsupported, the model
                is not a pooling model, or an IOProcessor request is given
                while no IOProcessor plugin is installed.
        """

        error_str = (
            "pooling_task required for `LLM.encode`\n"
            "Please use one of the more specific methods or set the "
            "pooling_task when using `LLM.encode`:\n"
            "  - For embeddings, use `LLM.embed(...)` "
            'or `pooling_task="embed"`.\n'
            "  - For classification logits, use `LLM.classify(...)` "
            'or `pooling_task="classify"`.\n'
            "  - For similarity scores, use `LLM.score(...)`.\n"
            "  - For rewards, use `LLM.reward(...)` "
            'or `pooling_task="token_classify"`\n'
            "  - For token classification, "
            'use `pooling_task="token_classify"`\n'
            '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
        )

        if pooling_task is None:
            raise ValueError(error_str)

        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        # Reject unsupported tasks before doing any (potentially expensive)
        # IOProcessor parsing / pre-processing.
        if pooling_task not in self.supported_tasks:
            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

        # A dict carrying a "data" key is a plugin request, not a model prompt:
        # the IOProcessor turns it into model prompts here and turns the pooled
        # outputs back into its own result after the engine has run.
        io_processor_prompt = isinstance(prompts, dict) and "data" in prompts
        if io_processor_prompt:
            io_processor = self.io_processor
            if io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            validated_prompt = io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = io_processor.pre_process(prompt=validated_prompt)

            # The plugin validates the given pooling params (or generates them).
            if is_list_of(pooling_params, PoolingParams):
                pooling_params = [
                    io_processor.validate_or_generate_params(param)
                    for param in as_iter(pooling_params)
                ]
            else:
                assert not isinstance(pooling_params, Sequence)
                pooling_params = io_processor.validate_or_generate_params(
                    pooling_params
                )
        elif pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if not io_processor_prompt:
            return model_outputs

        # get the post-processed model outputs
        assert self.io_processor is not None
        processed_outputs = self.io_processor.post_process(
            model_output=model_outputs
        )

        return [
            PoolingRequestOutput[Any](
                request_id="",
                outputs=processed_outputs,
                num_cached_tokens=getattr(
                    processed_outputs, "num_cached_tokens", 0
                ),
                prompt_token_ids=[],
                finished=True,
            )
        ]

    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per prompt.

        Thin wrapper over `encode` with `pooling_task="embed"`. Prompts are
        batched automatically, considering the memory constraint; for the best
        performance, pass all prompts in a single call.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode` (kept for backwards
                compatibility).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Forwarded unchanged to `encode`.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `embed` task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        pooled = self.encode(
            prompts,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            tokenization_kwargs: Forwarded unchanged to `encode`, where they
                override the tokenization kwargs set in `pooling_params`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `classify` task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Compute rewards for each prompt.

        This is `encode` run with `pooling_task="token_classify"`; whether the
        model supports that task is checked inside `encode`.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode` (kept for backwards
                compatibility).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Forwarded unchanged to `encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        reward_outputs = self.encode(
            prompts,
            pooling_task="token_classify",
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_params=pooling_params,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        return reward_outputs

    def _embedding_score(
        self,
        tokenizer: TokenizerLike,
        text_1: list[str | TextPrompt | TokensPrompt],
        text_2: list[str | TextPrompt | TokensPrompt],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their embeddings.

        Both sides are embedded together in a single `encode` call with
        `pooling_task="embed"`. When `text_1` holds a single entry, its
        embedding is paired with every entry of `text_2`; otherwise entries
        are paired by position.
        """
        n_first = len(text_1)
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            pooling_task="embed",
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            tokenization_kwargs=tokenization_kwargs,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )

        # Split the joint result back into the two sides.
        first_side: list[PoolingRequestOutput] = embeddings[:n_first]
        second_side: list[PoolingRequestOutput] = embeddings[n_first:]

        # 1 -> N: broadcast the single query embedding over all documents.
        if len(first_side) == 1:
            first_side = first_side * len(second_side)

        similarities = _cosine_similarity(
            tokenizer=tokenizer, embed_1=first_side, embed_2=second_side
        )
        validated = self.engine_class.validate_outputs(
            similarities, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(entry) for entry in validated]

    def _cross_encoding_score(
        self,
        tokenizer: TokenizerLike,
        data_1: list[str] | list[ScoreContentPartParam],
        data_2: list[str] | list[ScoreContentPartParam],
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        score_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Score `(data_1[i], data_2[i])` pairs with a cross-encoder.

        Each pair is rendered into a single engine prompt by
        `get_score_prompt` and run through the engine with pooling params
        verified for the `score` task. A single-element `data_1` is paired
        with every element of `data_2`.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        model_config = self.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        # 1 -> N: replicate the single query so it pairs with every document.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")

        pooling_params.verify("score", model_config)

        # Private copy: it is handed to `_validate_truncation_size` and
        # `get_score_prompt`, and the caller's dict must stay untouched.
        tokenization_kwargs = dict(tokenization_kwargs or {})

        _validate_truncation_size(
            model_config.max_model_len, truncate_prompt_tokens, tokenization_kwargs
        )

        prompts: list[PromptType] = []
        pooling_params_list: list[PoolingParams] = []

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            # token_type_ids are moved out of the prompt and attached, in
            # compressed form, to a per-request clone of the pooling params.
            if token_type_ids := engine_prompt.pop("token_type_ids", None):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

    def score(
        self,
        data_1: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        data_2: SingletonPrompt | Sequence[SingletonPrompt] | ScoreMultiModalParam,
        /,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. The engine automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Forwarded unchanged to the underlying
                cross-encoding / embedding scoring helper (presumably the
                maximum number of prompt tokens to keep).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The chat template to use for the scoring. If None, we
                use the model's default chat template. Only allowed for
                cross-encoder models.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, supports neither
                `embed` nor `classify`, is a cross-encoder with
                `num_labels != 1`, if `chat_template` is given for a
                non-cross-encoder model, or if a text-only model receives
                multi-modal or otherwise unsupported prompts.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        if (
            model_config.is_cross_encoder
            and getattr(model_config.hf_config, "num_labels", 0) != 1
        ):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        if not model_config.is_cross_encoder and chat_template is not None:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs, so token-id
        # prompts are decoded back to text below.
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(
                data: SingletonPrompt
                | Sequence[SingletonPrompt]
                | ScoreMultiModalParam,
            ):
                # A dict carrying "content" is a ScoreMultiModalParam, which
                # a text-only model cannot consume.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError(
                        "ScoreMultiModalParam is not supported "
                        f"for {model_config.architecture}"
                    )

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt) -> str:
                """Normalize a text-only prompt to a plain string."""
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError(
                            "Multi-modal prompt is not supported for scoring"
                        )
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"]
                        )
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check instead of `assert`, which is stripped when
                # Python runs with -O and would let bad input through.
                if not isinstance(prompt, str):
                    raise ValueError(
                        "Only text or token-id prompts are supported for "
                        f"scoring with {model_config.architecture}, "
                        f"got {type(prompt).__name__}"
                    )
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Unwrap ScoreMultiModalParam to its content list; wrap bare strings.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
                score_template=chat_template,
            )
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request,
            )

    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to end profiling."""
        engine = self.llm_engine
        engine.stop_profile()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache.

        Args:
            reset_running_requests: Forwarded as the first positional
                argument of the engine's `reset_prefix_cache`.
            reset_connector: Forwarded as the second positional argument.

        Returns:
            Whatever the engine's `reset_prefix_cache` returns (annotated as
            `bool`).
        """
        engine = self.llm_engine
        result = engine.reset_prefix_cache(
            reset_running_requests,
            reset_connector,
        )
        return result

    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset (via `reset_prefix_cache()`) before the
        engine is put to sleep.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

    def wake_up(self, tags: list[str] | None = None):
1506
        """
1507
1508
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1509

1510
        Args:
1511
1512
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1513
                `("weights", "kv_cache")`. If None, all memory is reallocated.
1514
                wake_up should be called with all tags (or None) before the
1515
1516
1517
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1518

1519
1520
1521
1522
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus, exactly as returned by the
            engine's `get_metrics()`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

    def _validate_and_add_requests(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        params: SamplingParams
        | Sequence[SamplingParams]
        | PoolingParams
        | Sequence[PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Validate per-request arguments and add every prompt to the engine.

        Args:
            prompts: A single prompt (`str` or `dict`) or a sequence of
                prompts. A single prompt is wrapped into a one-element list.
            params: One params object shared by all prompts, or a sequence
                with exactly one entry per prompt.
            use_tqdm: If truthy, show a progress bar while adding requests;
                a callable is used as the tqdm factory.
            lora_request: One LoRA request shared by all prompts, a sequence
                with one entry per prompt, or `None`.
            priority: Optional per-prompt priorities whose length must match
                the number of prompts. Requests default to priority 0.
            tokenization_kwargs: Extra tokenization options forwarded to
                every `_add_request` call.

        Raises:
            ValueError: If `params`, `lora_request` or `priority` is a
                sequence whose length differs from the number of prompts.

        Note:
            `SamplingParams` objects passed in are mutated in place: their
            `output_kind` is set to `RequestOutputKind.FINAL_ONLY`. If adding
            a request fails, the requests already added by this call are
            aborted and the original exception is re-raised.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]  # type: ignore[list-item]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params must be the same.")
        if isinstance(lora_request, Sequence) and len(lora_request) != num_requests:
            raise ValueError(
                "The lengths of prompts and lora_request must be the same."
            )
        if priority is not None and len(priority) != num_requests:
            raise ValueError(
                "The lengths of prompts "
                f"({num_requests}) and priority ({len(priority)}) "
                "must be the same."
            )

        for sp in params if isinstance(params, Sequence) else (params,):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        added_request_ids: list[str] = []

        try:
            for i, prompt in enumerate(it):
                request_id = self._add_request(
                    prompt,
                    params[i] if isinstance(params, Sequence) else params,
                    lora_request=(
                        lora_request[i]
                        if isinstance(lora_request, Sequence)
                        else lora_request
                    ),
                    priority=priority[i] if priority else 0,
                    tokenization_kwargs=tokenization_kwargs,
                )
                added_request_ids.append(request_id)
        except Exception:
            # Roll back the requests this call already enqueued so that a
            # partially-added batch is not left running in the engine.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            # Bare `raise` re-raises the active exception with its original
            # traceback, without adding this frame as a new raise site.
            raise

    def _process_inputs(
        self,
        request_id: str,
        engine_prompt: PromptType,
        params: SamplingParams | PoolingParams,
        *,
        lora_request: LoRARequest | None,
        priority: int,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> tuple[EngineCoreRequest, dict[str, Any]]:
        """Use the input processor to turn a prompt into an engine request.

        Args:
            request_id: ID to assign to the new request.
            engine_prompt: The prompt to process.
            params: Sampling or pooling parameters for the request. Their
                `truncate_prompt_tokens` is validated against the model's
                `max_model_len`.
            lora_request: Optional LoRA request.
            priority: Scheduling priority of the request.
            tokenization_kwargs: Optional tokenization options. The caller's
                dict is never modified; a copy is used instead.

        Returns:
            A tuple `(engine_request, tokenization_kwargs)`. The second item
            is the private copy of the tokenization options that was passed
            to `_validate_truncation_size` and the input processor.
        """
        # Work on a private copy so the caller's dict is never mutated
        # (presumably `_validate_truncation_size` may add entries to it).
        tokenization_kwargs = dict(tokenization_kwargs or {})
        _validate_truncation_size(
            self.model_config.max_model_len,
            params.truncate_prompt_tokens,
            tokenization_kwargs,
        )

        engine_request = self.input_processor.process_inputs(
            request_id,
            engine_prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
        return engine_request, tokenization_kwargs

    def _add_request(
        self,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> str:
        """Process a single prompt and submit it to the engine.

        A fresh request ID is taken from `self.request_counter`. The prompt
        is run through `_process_inputs`, and the result is handed to
        `llm_engine.add_request`.

        Args:
            prompt: The prompt to submit.
            params: Sampling or pooling parameters for the request.
            lora_request: Optional LoRA request.
            priority: Scheduling priority (default 0).
            tokenization_kwargs: Optional tokenization options.

        Returns:
            The `request_id` of the processed engine request.
        """
        # Only the text component is used here; it is forwarded to the
        # engine as `prompt_text`.
        prompt_text, _, _ = get_prompt_components(prompt)
        request_id = str(next(self.request_counter))

        engine_request, tokenization_kwargs = self._process_inputs(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
        )

        self.llm_engine.add_request(
            request_id,
            engine_request,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
            prompt_text=prompt_text,
        )
        return engine_request.request_id

    def _run_engine(
        self, *, use_tqdm: bool | Callable[..., tqdm] = True
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until every pending request has finished.

        Args:
            use_tqdm: If truthy, show a "Processed prompts" progress bar with
                estimated input/output token throughput; a callable is used
                as the tqdm factory. The bar is always closed, even if
                stepping the engine raises.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        pbar: tqdm | None = None
        num_requests = 0
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs
                        )
                        elapsed = pbar.format_dict["elapsed"]
                        # `elapsed` can be exactly 0 on coarse clocks when a
                        # request finishes immediately; skip the estimate
                        # instead of dividing by zero.
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s"
                            )
                        pbar.update(n)
                    else:
                        pbar.update(1)
                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Close the bar even if an engine step raised.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))

    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model."""
        # Cache the result to avoid repeated collective_rpc calls
        if self._cached_repr is None:
            results = self.llm_engine.collective_rpc("get_model_inspection")
            # In distributed settings, we get results from all workers
            # Just return the first one (they should all be the same)
            if results:
                self._cached_repr = results[0]
            else:
                self._cached_repr = f"LLM(model={self.model_config.model!r})"
        return self._cached_repr