llm.py 80.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Callable, Sequence
7
from typing import TYPE_CHECKING, Any, cast
8

9
import cloudpickle
10
import torch.nn as nn
11
from pydantic import ValidationError
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar
14

15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
22
    AttentionConfig,
23
    CompilationConfig,
24
    PoolerConfig,
25
    ProfilerConfig,
26
27
28
    StructuredOutputsConfig,
    is_init_field,
)
29
from vllm.config.compilation import CompilationMode
30
from vllm.config.model import (
31
32
    ConvertOption,
    HfOverrides,
33
    ModelDType,
34
    RunnerOption,
35
    TokenizerMode,
36
)
37
38
39
40
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
41
from vllm.engine.arg_utils import EngineArgs
42
43
44
45
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
46
from vllm.entrypoints.pooling.score.utils import (
47
    ScoreData,
48
49
50
    ScoreMultiModalParam,
    _cosine_similarity,
    compress_token_type_ids,
51
    compute_maxsim_score,
52
    get_score_prompt,
53
    validate_score_input,
54
)
55
from vllm.entrypoints.utils import log_non_default_args
56
from vllm.inputs.data import (
57
58
59
60
61
62
    DataPrompt,
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
63
from vllm.logger import init_logger
64
from vllm.lora.request import LoRARequest
65
from vllm.model_executor.layers.quantization import QuantizationMethods
66
67
68
69
70
71
72
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
73
from vllm.platforms import current_platform
74
from vllm.pooling_params import PoolingParams
75
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
76
from vllm.renderers.inputs import DictPrompt, TokPrompt
77
78
79
80
81
82
from vllm.renderers.inputs.preprocess import (
    conversation_to_seq,
    extract_prompt_components,
    parse_model_prompt,
    prompt_to_seq,
)
83
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
84
from vllm.tasks import PoolingTask
85
86
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
yhu422's avatar
yhu422 committed
87
from vllm.usage.usage_lib import UsageContext
88
from vllm.utils.counter import Counter
89
from vllm.v1.engine.llm_engine import LLMEngine
90
from vllm.v1.sample.logits_processor import LogitsProcessor
91

92
93
94
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

95
96
logger = init_logger(__name__)

# Type of per-request parameters: SamplingParams, PoolingParams, or None.
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)

# Generic result type (falls back to Any) used by the RPC/apply helpers.
_R = TypeVar("_R", default=Any)

100
101

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
102
103
104
105
106
107
108
109
110
111
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
112
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
113
114
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
115
116
117
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
118
119
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
120
121
122
123
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
124
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
126
127
128
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
132
        quantization: The method used to quantize the model weights. Currently,
133
            we support "awq", "gptq", and "fp8" (experimental).
134
135
136
137
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
138
139
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
140
141
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
142
143
144
145
146
147
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
148
149
150
151
152
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vllm can automatically infer the kv cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the kv cache memory size. kv_cache_memory_bytes
            allows more fine-grained control of how much memory gets used than
            gpu_memory_utilization does. Note that when kv_cache_memory_bytes
            is not None, gpu_memory_utilization is ignored.
156
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
162
163
164
165
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
166
167
168
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
169
        enable_return_routed_experts: Whether to return routed experts.
170
171
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
172
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `hf auth login` (stored in `~/.cache/huggingface/token`).
175
176
177
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
178
179
180
181
182
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
183
184
        pooler_config: Initialize non-default pooling config for the pooling model,
            e.g., `PoolerConfig(seq_pooling_type="MEAN", use_activation=False)`.
185
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig` instance. If it is an integer, it is used as
            the compilation mode. If it is a dictionary, it can specify the
            full compilation configuration.
188
189
190
191
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
192
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
193

194
195
    Note:
        This class is intended to be used for offline inference. For online
196
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
197
    """
198
199
200
201

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes the user-facing arguments (plain dicts, ints, classes) into
        the config objects passed to `EngineArgs`, creates the `LLMEngine`,
        and caches engine handles used by the other methods. See the class
        docstring for the meaning of each argument.

        Raises:
            ValueError: If `kv_transfer_config` is a dict that cannot be
                converted to a `KVTransferConfig`, or if `data_parallel_size`
                > 1 is requested without the `external_launcher` executor
                backend on a non-TPU platform.
        """
        # Stats logging is disabled unless the caller explicitly sets
        # `disable_log_stats`.
        kwargs.setdefault("disable_log_stats", True)

        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            # A class object (as opposed to a qualified string name) is
            # serialized with cloudpickle to avoid pickling issues.
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        raw_kv_transfer_config = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer_config, dict):
            from vllm.config.kv_transfer import KVTransferConfig

            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_kv_transfer_config
                )
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer_config,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance.

            `None` yields a default-constructed `cls`; a dict is filtered to
            the keys that are init fields of `cls` before construction; any
            other value is returned unchanged.
            """
            if value is None:
                return cls()
            if isinstance(value, dict):
                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
            return value

        if isinstance(compilation_config, int):
            # An integer selects the compilation mode.
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data-parallel usage, which may hang; only the
        # external launcher backend or TPU platforms are allowed through.
        dp_size = int(kwargs.get("data_parallel_size", 1))
        executor_backend = kwargs.get("distributed_executor_backend")
        if (
            dp_size > 1
            and executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by `get_default_sampling_params`.
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        self.model_config = self.llm_engine.model_config
        self.input_processor = self.llm_engine.input_processor
        self.io_processor = self.llm_engine.io_processor

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

365
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer held by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
367

368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
    def get_world_size(self, include_dp: bool = True) -> int:
        """Return the number of workers described by the parallel config.

        Args:
            include_dp: When True (the default), count workers across data
                parallelism as well (TP * PP * DP). When False, count only a
                single data-parallel replica (TP * PP).

        Returns:
            `parallel_config.world_size_across_dp` if `include_dp` is True,
            otherwise `parallel_config.world_size`.
        """
        cfg = self.llm_engine.vllm_config.parallel_config
        return cfg.world_size_across_dp if include_dp else cfg.world_size

385
    def reset_mm_cache(self) -> None:
        """Clear the multimodal caches of the input processor and the engine."""
        # Input-processor cache first, then the engine-side cache.
        processor = self.input_processor
        processor.clear_mm_cache()
        self.llm_engine.reset_mm_cache()

389
    def get_default_sampling_params(self) -> SamplingParams:
        """Build the default `SamplingParams` for this model.

        The dict returned by `model_config.get_diff_sampling_param()` is
        fetched on first use and cached in `self.default_sampling_params`.
        When that dict is non-empty it is passed to
        `SamplingParams.from_optional`; otherwise a plain `SamplingParams()`
        is returned. A fresh object is built on every call.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = self.model_config.get_diff_sampling_param()

        model_defaults = self.default_sampling_params
        if not model_defaults:
            return SamplingParams()
        return SamplingParams.from_optional(**model_defaults)

396
397
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """Generate completions for the given prompts.

        Prompts are batched automatically with the memory constraint in mind.
        For best throughput, collect all prompts into a single list and make
        one call to this method.

        Args:
            prompts: The prompts to the LLM. Pass a sequence of prompts for
                batch inference. See [PromptType][vllm.inputs.PromptType]
                for the accepted formats of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, the model's default sampling parameters are used.
                A single value is applied to every prompt; a list must have
                the same length as `prompts` and is paired with them one by
                one.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the
                prompt at the same index.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            `RequestOutput` objects holding the generated completions, in the
            same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not `"generate"`.
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        params = (
            self.get_default_sampling_params()
            if sampling_params is None
            else sampling_params
        )

        completion_outputs = self._run_completion(
            prompts=prompts,
            params=params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
        return self.engine_class.validate_outputs(completion_outputs, RequestOutput)
459

460
    def _get_modality_specific_lora_reqs(
        self,
        prompts: Sequence[DictPrompt | TokPrompt],
        lora_request: list[LoRARequest] | LoRARequest | None,
    ):
        """Apply default per-modality LoRAs to multimodal prompts.

        Args:
            prompts: The prompts being submitted.
            lora_request: The user-provided LoRA request(s): a single request
                (or `None`) shared by all prompts, or a sequence paired with
                `prompts` one by one.

        Returns:
            `lora_request` unchanged when there is no LoRA config, the model
            is not multimodal, or `default_mm_loras` is not configured.
            Otherwise, a list with one entry per prompt, each resolved by
            `_resolve_single_prompt_mm_lora`.
        """
        lora_config = self.llm_engine.vllm_config.lora_config

        # Nothing to resolve: leave the request(s) exactly as provided.
        # `lora_config` is already known to be non-None by the last check.
        if (
            lora_config is None
            or not self.model_config.is_multimodal_model
            or lora_config.default_mm_loras is None
        ):
            return lora_request

        # Broadcast a single (or absent) request so each prompt has an entry.
        per_prompt_loras = (
            lora_request
            if isinstance(lora_request, Sequence)
            else [lora_request] * len(prompts)
        )

        default_mm_loras = lora_config.default_mm_loras
        # NOTE(review): zip stops at the shorter input; length agreement
        # between `prompts` and a list `lora_request` is assumed to be
        # validated by the caller — confirm.
        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                prompt_lora,
                default_mm_loras,
            )
            for prompt, prompt_lora in zip(prompts, per_prompt_loras)
        ]

493
494
    def _resolve_single_prompt_mm_lora(
        self,
        prompt: DictPrompt | TokPrompt,
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
    ) -> LoRARequest | None:
        """Choose the LoRA request for one (possibly multimodal) prompt.

        Args:
            prompt: The prompt whose `multi_modal_data` keys are matched
                against the registered modality LoRAs.
            lora_request: The LoRA request explicitly provided for this
                prompt, if any.
            default_mm_loras: Mapping from modality name to the path of the
                default LoRA registered for that modality.

        Returns:
            A new `LoRARequest` for the matching modality only when exactly
            one registered modality appears in the prompt's multimodal data
            and no `lora_request` was given. In every other case,
            `lora_request` is returned unchanged.
        """
        if not default_mm_loras or not (
            mm_data := prompt.get("multi_modal_data") or {}
        ):
            return lora_request

        intersection = set(
            mm_data.keys()  # type: ignore
        ).intersection(default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped",
                intersection,
            )
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; warn only when its ID
        # differs from the modality LoRA that would otherwise be applied.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt"
                )
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

544
545
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Invoke `method` on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and sent to every worker. A callable must accept an
                extra `self` argument (the worker object) in addition to
                whatever `args` and `kwargs` supply.
            timeout: Maximum number of seconds to wait. On expiry a
                [`TimeoutError`][] is raised; `None` waits indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list holding each worker's result.

        Note:
            Prefer this API for control messages only; set up a separate
            data-plane channel for moving data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
575
576

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Call `func` on the model inside each worker and collect the results.

        !!! warning
            Returning large arrays or tensors from `func` adds data-transfer
            overhead. If you must return them, move them to CPU first so they
            do not take up additional VRAM.
        """
        engine = self.llm_engine
        return engine.apply_model(func)
588

589
590
    def _get_beam_search_lora_requests(
        self,
        lora_request: list[LoRARequest] | LoRARequest | None,
        prompts: list[TokensPrompt | TextPrompt],
    ) -> list[LoRARequest | None]:
        """Get the optional LoRA request corresponding to each prompt.

        Args:
            lora_request: Either a single LoRA request (or `None`) shared by
                all prompts, or a sequence with one entry per prompt.
            prompts: The prompts passed to beam search.

        Returns:
            A list with the same length as `prompts`; element `i` is the LoRA
            request (or `None`) to use for prompt `i`.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from the number of prompts.
            TypeError: If `lora_request` is none of the accepted types.
        """
        if lora_request is None or isinstance(lora_request, LoRARequest):
            # A single (possibly absent) request applies to every prompt.
            return [lora_request] * len(prompts)

        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            # One request per prompt; copy so the caller's list isn't aliased.
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

605
606
    def beam_search(
        self,
607
        prompts: list[TokensPrompt | TextPrompt],
608
        params: BeamSearchParams,
609
        lora_request: list[LoRARequest] | LoRARequest | None = None,
610
        use_tqdm: bool = False,
611
        concurrency_limit: int | None = None,
612
    ) -> list[BeamSearchOutput]:
613
614
615
616
617
618
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
619
            params: The beam search parameters.
620
            lora_request: LoRA request to use for generation, if any.
621
            use_tqdm: Whether to use tqdm to display the progress bar.
622
623
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
624
        """
625
626
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
627
628
629
630
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
631
632
        length_penalty = params.length_penalty

633
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
634

635
636
637
638
639
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
640

641
642
643
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
644
645
                "Disabling progress bar."
            )
646
647
648
649
650
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

651
652
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
653
654
655
656
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
657
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
658
            return TokensPrompt(**token_prompt_kwargs)
659

660
661
662
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
663
        beam_search_params = SamplingParams(
664
665
666
667
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
668
        )
669
        instances: list[BeamSearchInstance] = []
670

671
        for lora_req, prompt in zip(lora_requests, prompts):
672
673
674
675
676
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
677
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
678

679
680
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
681
682
683
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
684

685
            instances.append(
686
687
688
689
690
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
691
692
                ),
            )
693

694
        for prompt_start in range(0, len(prompts), concurrency_limit):
695
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
696
697
698

            token_iter = range(max_tokens)
            if use_tqdm:
699
700
701
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
702
703
704
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
705
706
                    "reflect instance-level progress."
                )
707
708
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
709
710
                    sum((instance.beams for instance in instances_batch), [])
                )
711
712
                pos = [0] + list(
                    itertools.accumulate(
713
714
715
                        len(instance.beams) for instance in instances_batch
                    )
                )
716
                instance_start_and_end: list[tuple[int, int]] = list(
717
718
                    zip(pos[:-1], pos[1:])
                )
719
720
721
722
723
724

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
725
726
727
728
729
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
730
731
732

                # only runs for one step
                # we don't need to use tqdm here
733
734
735
736
737
738
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
739

740
741
742
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
743
744
745
746
747
748
749
750
751
752
753
754
755
756
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
757
                                    logprobs=current_beam.logprobs + [logprobs],
758
                                    lora_request=current_beam.lora_request,
759
760
761
762
763
764
765
766
767
768
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
769
770
771
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
772
773
774
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
775
                    instance.beams = sorted_beams[:beam_width]
776
777
778
779

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
780
781
782
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
783
784
785
786
787
788
789
790
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
    def _get_cmpl_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
        """Build the `TokenizeParams` used for completion-style prompts.

        The token budget comes from `model_config.max_model_len`, lower-casing
        follows the encoder config's `do_lower_case` flag (default False), and
        `tokenization_kwargs` are layered on top via `with_kwargs`.
        """
        cfg = self.model_config
        lower_case = (cfg.encoder_config or {}).get("do_lower_case", False)

        base_params = TokenizeParams(
            max_total_tokens=cfg.max_model_len,
            do_lower_case=lower_case,
            # For Whisper, special tokens should be provided by the user based
            # on the task and language of their request. Also needed to avoid
            # appending an EOS token to the prompt which disrupts generation.
            add_special_tokens=not cfg.is_encoder_decoder,
        )
        return base_params.with_kwargs(tokenization_kwargs)

    def _preprocess_completion(
        self,
        prompts: Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[DictPrompt | TokPrompt]:
        """
        Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
        a format that can be passed to `_add_request`.

        Refer to [LLM.generate][] for a complete description of the arguments.

        Args:
            prompts: The raw prompts. Each one is parsed against the model
                config with `parse_model_prompt` before rendering.
            tokenization_kwargs: Overrides layered onto the completion
                tokenization parameters built by `_get_cmpl_tok_params`.

        Returns:
            The prompts as rendered by the engine renderer's `render_cmpl`.
            Unlike `_preprocess_chat`, this path does not go through
            `render_chat`.
        """
        renderer = self.llm_engine.renderer
        model_config = self.model_config

        parsed_prompts = [
            parse_model_prompt(model_config, prompt) for prompt in prompts
        ]
        tok_params = self._get_cmpl_tok_params(tokenization_kwargs)

        return renderer.render_cmpl(parsed_prompts, tok_params)

    def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
        """Build the `TokenizeParams` used for chat-rendered prompts.

        Same budget and lower-casing rules as the completion path, but
        special tokens are never added here (presumably because the chat
        template already emits them -- confirm). `tokenization_kwargs` are
        layered on top via `with_kwargs`.
        """
        cfg = self.model_config
        lower_case = (cfg.encoder_config or {}).get("do_lower_case", False)

        base_params = TokenizeParams(
            max_total_tokens=cfg.max_model_len,
            do_lower_case=lower_case,
            add_special_tokens=False,
        )
        return base_params.with_kwargs(tokenization_kwargs)

    def _preprocess_chat(
        self,
        conversations: Sequence[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[TokPrompt]:
        """
        Render a batch of conversations into engine prompts for other LLM APIs.

        Refer to [LLM.chat][] for a complete description of the arguments.

        The explicit `add_generation_prompt`, `continue_final_message` and
        `tools` arguments (plus a `tokenize` flag) are merged with the
        user-supplied `chat_template_kwargs` via `merge_kwargs`, and
        `mm_processor_kwargs` is attached to every prompt as an extra.

        Returns:
            The engine prompts produced by `renderer.render_chat`; the first
            element of its result is discarded.
        """
        renderer = self.llm_engine.renderer

        # `tokenize` is True only when the renderer's tokenizer is a
        # MistralTokenizer.
        template_overrides = dict(
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
        )
        chat_params = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=merge_kwargs(
                chat_template_kwargs, template_overrides
            ),
        )
        tok_params = self._get_chat_tok_params(tokenization_kwargs)

        _rendered_conversations, engine_prompts = renderer.render_chat(
            conversations,
            chat_params,
            tok_params,
            prompt_extras={"mm_processor_kwargs": mm_processor_kwargs},
        )
        return engine_prompts

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the [generate][vllm.LLM.generate] method to generate
        the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A sequence of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, asks the chat template to append
                a generation prompt (the opening of a new assistant turn) at
                the end of each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional list of tool definitions (one dict per tool)
                forwarded for chat-template rendering.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            mm_processor_kwargs: Overrides for `processor.__call__`.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.

        Raises:
            ValueError: If the model's runner type is not "generate".
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.chat() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        outputs = self._run_chat(
            messages=messages,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict
                containing a `"data"` key is instead routed through the
                installed IOProcessor plugin.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. On the non-plugin path,
                entries whose `task` is unset are shallow-copied before the
                task is filled in, so caller-owned objects are not mutated.
            truncate_prompt_tokens: Deprecated. Pass `truncate_prompt_tokens`
                via `tokenization_kwargs` instead.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use. Required.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If `pooling_task` is None, the model is not a pooling
                model, a plugin request is made with no IOProcessor installed,
                or a `pooling_params` entry already carries a different task.
        """
        import copy

        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if truncate_prompt_tokens is not None:
            warnings.warn(
                "The `truncate_prompt_tokens` parameter in `LLM.encode()` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_data(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)
            prompts_seq = prompt_to_seq(prompts)

            params_seq: Sequence[PoolingParams] = [
                self.io_processor.merge_pooling_params(param)
                for param in self._params_to_seq(
                    pooling_params,
                    len(prompts_seq),
                )
            ]
            for p in params_seq:
                if p.task is None:
                    p.task = "plugin"
        else:
            if pooling_params is None:
                # Use default pooling params.
                pooling_params = PoolingParams()

            prompts_seq = prompt_to_seq(prompts)

            # Fill in the task on private shallow copies rather than on the
            # incoming objects: `_params_to_seq` may return the caller's own
            # PoolingParams, so setting `task` in place would leak into later
            # calls (reusing one PoolingParams for `embed()` and then
            # `classify()` would fail the mismatch check below). Copies are
            # memoized by identity so one object repeated N times is copied
            # once.
            copies_by_id: dict[int, PoolingParams] = {}
            resolved_seq: list[PoolingParams] = []
            for param in self._params_to_seq(pooling_params, len(prompts_seq)):
                if param.task is None:
                    resolved = copies_by_id.get(id(param))
                    if resolved is None:
                        resolved = copy.copy(param)
                        resolved.task = pooling_task
                        copies_by_id[id(param)] = resolved
                    param = resolved
                elif param.task != pooling_task:
                    msg = (
                        f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                    )
                    raise ValueError(msg)
                resolved_seq.append(param)
            params_seq = resolved_seq

        outputs = self._run_completion(
            prompts=prompts_seq,
            params=params_seq,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if use_io_processor:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(model_outputs)

            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
                    prompt_token_ids=[],
                    finished=True,
                )
            ]
        else:
            return model_outputs

    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Produce one embedding vector per prompt.

        Prompts are batched automatically with the memory budget in mind, so
        passing every prompt in a single list gives the best throughput.

        Args:
            prompts: A single prompt or a batch of prompts. See
                [PromptType][vllm.inputs.PromptType] for the accepted formats.
            truncate_prompt_tokens: If given, merged into
                `tokenization_kwargs` as `truncate_prompt_tokens`.
            use_tqdm: `True` shows a tqdm progress bar; a callable (e.g.,
                `functools.partial(tqdm, leave=False)`) builds the bar;
                `False` disables it.
            pooling_params: Pooling parameters; None selects the defaults.
            lora_request: Optional LoRA request(s) to apply.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            `EmbeddingRequestOutput` objects, ordered like the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        # The legacy argument is folded into the tokenization overrides here
        # so `encode()` never sees (or warns about) it.
        effective_tok_kwargs = (
            tokenization_kwargs
            if truncate_prompt_tokens is None
            else merge_kwargs(
                tokenization_kwargs,
                {"truncate_prompt_tokens": truncate_prompt_tokens},
            )
        )

        pooled_outputs = self.encode(
            prompts,
            pooling_task="embed",
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
            tokenization_kwargs=effective_tok_kwargs,
        )
        return [EmbeddingRequestOutput.from_base(out) for out in pooled_outputs]

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            class logits in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: If given, merged into
                `tokenization_kwargs` as `truncate_prompt_tokens`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        if truncate_prompt_tokens is not None:
            # Fold the argument into `tokenization_kwargs` here (as `embed()`
            # does) rather than forwarding it to `encode()`, where it raises a
            # DeprecationWarning that names `LLM.encode()` and, because of
            # `stacklevel=2`, is attributed to this method instead of the
            # caller. The resulting tokenization kwargs are the same.
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_classify",
            tokenization_kwargs=tokenization_kwargs,
        )

    def _embedding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by cosine similarity of their embeddings.

        Both sides are embedded in one `encode` call (all of `data_1` first,
        then all of `data_2`). When `data_1` holds a single item, its
        embedding is repeated to pair with every item of `data_2`.

        Raises:
            NotImplementedError: If any item is not a plain string.
        """
        tokenizer = self.get_tokenizer()

        texts: list[str] = []
        for item in data_1 + data_2:
            if isinstance(item, str):
                texts.append(item)
                continue
            raise NotImplementedError(
                "Embedding scores currently do not support multimodal input."
            )

        embeddings = self.encode(
            texts,
            pooling_task="embed",
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
            tokenization_kwargs=tokenization_kwargs,
        )

        n_left = len(data_1)
        left_embeds = embeddings[:n_left]
        right_embeds = embeddings[n_left:]

        # Broadcast a single left-hand item against every right-hand item.
        if len(left_embeds) == 1:
            left_embeds = left_embeds * len(right_embeds)

        similarity_outputs = _cosine_similarity(
            tokenizer=tokenizer,
            embed_1=left_embeds,
            embed_2=right_embeds,
        )

        validated = self.engine_class.validate_outputs(
            similarity_outputs, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(out) for out in validated]

    def _late_interaction_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Score pairs with a late-interaction model (ColBERT-style MaxSim).

        All queries and documents are encoded in a single ``encode`` call using
        the ``"token_embed"`` pooling task, which yields per-token embeddings.
        Each (query, document) pair is then scored with
        ``compute_maxsim_score``. A single query is broadcast against every
        document.

        Args:
            data_1: Query inputs; only plain strings are accepted.
            data_2: Document inputs; only plain strings are accepted.
            use_tqdm: Forwarded to ``encode``.
            pooling_params: Forwarded to ``encode``.
            lora_request: Forwarded to ``encode``.
            tokenization_kwargs: Forwarded to ``encode``.

        Returns:
            One ``ScoringRequestOutput`` per scored pair.

        Raises:
            NotImplementedError: If any input is not a string (e.g. multimodal).
        """
        from vllm.outputs import PoolingOutput

        tokenizer = self.get_tokenizer()

        def _as_texts(items: list[ScoreData]) -> list[str]:
            # Late interaction only handles text; reject anything else.
            texts: list[str] = []
            for item in items:
                if not isinstance(item, str):
                    raise NotImplementedError(
                        "Late interaction scores currently do not support multimodal input."
                    )
                texts.append(item)
            return texts

        queries = _as_texts(data_1)
        documents = _as_texts(data_2)

        outputs: list[PoolingRequestOutput] = self.encode(
            queries + documents,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        split = len(queries)
        query_outputs: list[PoolingRequestOutput] = outputs[:split]
        doc_outputs: list[PoolingRequestOutput] = outputs[split:]

        # 1 -> N: reuse the single query embedding for every document.
        if len(query_outputs) == 1:
            query_outputs = query_outputs * len(doc_outputs)

        # When the tokenizer defines a pad token, it separates the query and
        # document token ids in the reported prompt.
        pad_token_id = tokenizer.pad_token_id
        separator: list[int] = [] if pad_token_id is None else [pad_token_id]

        pair_outputs: list[PoolingRequestOutput] = []
        for query_out, doc_out in zip(query_outputs, doc_outputs):
            score = compute_maxsim_score(query_out.outputs.data, doc_out.outputs.data)
            pair_outputs.append(
                PoolingRequestOutput(
                    request_id=f"{query_out.request_id}_{doc_out.request_id}",
                    outputs=PoolingOutput(data=score),
                    prompt_token_ids=(
                        query_out.prompt_token_ids
                        + separator
                        + doc_out.prompt_token_ids
                    ),
                    num_cached_tokens=(
                        query_out.num_cached_tokens + doc_out.num_cached_tokens
                    ),
                    finished=True,
                )
            )

        validated = self.engine_class.validate_outputs(
            pair_outputs, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(out) for out in validated]

1407
1408
    def _cross_encoding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
        score_template: str | None,
    ) -> list[ScoringRequestOutput]:
        """
        Score (data_1, data_2) pairs with a cross-encoder model.

        Each pair is rendered into a single prompt by ``get_score_prompt`` and
        run through the engine with a ``"score"`` pooling task. A single
        ``data_1`` item is broadcast against every ``data_2`` item.

        Args:
            data_1: Query inputs (one, or one per ``data_2`` item).
            data_2: Inputs to pair with the queries.
            use_tqdm: Progress-bar setting forwarded to ``_run_completion``.
            pooling_params: Base pooling parameters. If None, a
                ``PoolingParams(task="score")`` is created. If its ``task`` is
                None, a copy with ``task="score"`` is used and the caller's
                object is left unmodified.
            lora_request: Forwarded to ``_run_completion``.
            tokenization_kwargs: Forwarded to ``get_score_prompt``.
            score_template: Template forwarded to ``get_score_prompt``.

        Returns:
            One ``ScoringRequestOutput`` per pair.

        Raises:
            ValueError: If the tokenizer is a ``MistralTokenizer``.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        # 1 -> N: replicate the single query so it pairs with every document.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        elif pooling_params.task is None:
            # Work on a copy so the caller's PoolingParams is not mutated.
            pooling_params = pooling_params.clone()
            pooling_params.task = "score"

        pooling_params_list: list[PoolingParams] = []
        prompts: list[PromptType] = []

        for query, document in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=query,
                data_2=document,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            # token_type_ids are removed from the prompt and sent instead, in
            # compressed form, through a per-request copy of the params.
            if token_type_ids := engine_prompt.pop("token_type_ids", None):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        outputs = self._run_completion(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

1469
1470
    def score(
        self,
        data_1: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        data_2: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
        `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.

        How a pair is scored depends on the model:

        - cross-encoder models score each rendered pair directly;
        - late-interaction models (e.g. ColBERT) use MaxSim over per-token
          embeddings (text inputs only);
        - other models compare the embeddings of both sides with cosine
          similarity.

        This class automatically batches the prompts, considering the memory
        constraint. For the best performance, put all of your inputs into a
        single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: The chat template to use for the scoring. If None,
                we use the model's default chat template. Only supported for
                cross-encoder models.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a pooling model, does not
                support scoring, is a cross-encoder with `num_labels != 1`, or
                if `chat_template` is given for a non-cross-encoder model.
        """
        model_config = self.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks

        # Late interaction models (e.g., ColBERT) use token_embed for scoring
        is_late_interaction = model_config.is_late_interaction
        if not is_late_interaction and all(
            t not in supported_tasks for t in ("embed", "classify")
        ):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        if (
            model_config.is_cross_encoder
            and getattr(model_config.hf_config, "num_labels", 0) != 1
        ):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        if not model_config.is_cross_encoder and chat_template is not None:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        is_multimodal_model = model_config.is_multimodal_model
        architecture = model_config.architecture

        score_data_1, score_data_2 = validate_score_input(
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            is_multimodal_model=is_multimodal_model,
            architecture=architecture,
        )

        tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
        encode_kwargs = tok_params.get_encode_kwargs()

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
                score_template=chat_template,
            )
        elif is_late_interaction:
            return self._late_interaction_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
            )
        else:
            return self._embedding_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
            )
1599

1600
1601
1602
1603
1604
1605
    def start_profile(self) -> None:
        """Start profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

1606
1607
1608
1609
1610
1611
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache.

        Args:
            reset_running_requests: Passed to the engine as the first
                positional argument.
            reset_connector: Passed to the engine as the second positional
                argument.

        Returns:
            The boolean reported by the engine's ``reset_prefix_cache``.
        """
        engine = self.llm_engine
        result = engine.reset_prefix_cache(reset_running_requests, reset_connector)
        return result
1612

1613
1614
1615
1616
1617
1618
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset before the engine is put to sleep.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

1635
    def wake_up(self, tags: list[str] | None = None):
1636
        """
1637
1638
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1639

1640
        Args:
1641
1642
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1643
                `("weights", "kv_cache")`. If None, all memory is reallocated.
1644
                wake_up should be called with all tags (or None) before the
1645
1646
1647
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1648

1649
1650
1651
1652
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus, as reported by the engine's
            `get_metrics`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1661
    def _params_to_seq(
1662
        self,
1663
        params: _P | Sequence[_P],
1664
        num_requests: int,
1665
    ) -> Sequence[_P]:
1666
1667
1668
1669
        if isinstance(params, Sequence):
            if len(params) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({params}) "
1670
                    f"and params ({len(params)}) must be the same."
1671
1672
                )

1673
            return params
1674

1675
1676
1677
1678
1679
1680
1681
        return [params] * num_requests

    def _lora_request_to_seq(
        self,
        lora_request: LoRARequest | None | Sequence[LoRARequest | None],
        num_requests: int,
    ) -> Sequence[LoRARequest | None]:
        """Expand ``lora_request`` to one entry per request.

        Args:
            lora_request: A single LoRA request (or None), which is repeated
                for every request, or a sequence holding one entry per request.
            num_requests: Number of prompts being submitted.

        Returns:
            A sequence of length ``num_requests``. A sequence input is returned
            as-is.

        Raises:
            ValueError: If a sequence is given whose length differs from
                ``num_requests``.
        """
        if isinstance(lora_request, Sequence):
            if len(lora_request) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and lora_request ({len(lora_request)}) must be the same."
                )

            return lora_request

        return [lora_request] * num_requests
1692

1693
1694
1695
1696
1697
    def _priority_to_seq(
        self,
        priority: list[int] | None,
        num_requests: int,
    ) -> Sequence[int]:
1698
1699
1700
1701
1702
1703
1704
        if priority is not None:
            if len(priority) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and priority ({len(priority)}) must be the same."
                )

1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
            return priority

        return [0] * num_requests

    def _run_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ):
        """Preprocess completion-style prompts, submit them, and run to completion.

        Args:
            prompts: One prompt or a sequence of prompts.
            params: One params object shared by all prompts, or one per prompt.
            use_tqdm: Progress-bar setting for submission and execution.
            lora_request: LoRA request(s); resolved per prompt by
                ``_get_modality_specific_lora_reqs``.
            priority: Optional per-request priorities.
            tokenization_kwargs: Tokenization overrides for preprocessing.

        Returns:
            The finished outputs returned by ``_run_engine``.
        """
        seq_prompts = prompt_to_seq(prompts)
        seq_params = self._params_to_seq(params, len(seq_prompts))

        if any(param.truncate_prompt_tokens is not None for param in seq_params):
            # TODO: Remove this after deprecating `param.truncate_prompt_tokens`
            # Then, move the code from the `else` block to the top and let
            # `self._preprocess_completion` handle prompt normalization
            engine_prompts: Sequence[DictPrompt | TokPrompt] = [
                engine_prompt
                for prompt, param in zip(seq_prompts, seq_params)
                for engine_prompt in self._preprocess_completion(
                    [prompt],
                    tokenization_kwargs=merge_kwargs(
                        tokenization_kwargs,
                        dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
                    ),
                )
            ]
        else:
            engine_prompts = self._preprocess_completion(
                seq_prompts,
                tokenization_kwargs=tokenization_kwargs,
            )

        self._validate_and_add_requests(
            prompts=engine_prompts,
            params=seq_params,
            use_tqdm=use_tqdm,
            lora_request=self._get_modality_specific_lora_reqs(
                engine_prompts, lora_request
            ),
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )

        return self._run_engine(use_tqdm=use_tqdm)

    def _run_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ):
        """Render chat conversations into engine prompts, submit, and run them.

        The conversations are normalized with ``conversation_to_seq`` and
        rendered by ``_preprocess_chat`` using the chat-template options. The
        results are then submitted through ``_validate_and_add_requests``.

        Returns:
            The finished outputs returned by ``_run_engine``.
        """
        conversations = conversation_to_seq(messages)
        engine_prompts = self._preprocess_chat(
            conversations,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        # Resolve the LoRA request per prompt (modality-aware) before submission.
        lora_requests = self._get_modality_specific_lora_reqs(
            engine_prompts, lora_request
        )

        self._validate_and_add_requests(
            prompts=engine_prompts,
            params=params,
            use_tqdm=use_tqdm,
            lora_request=lora_requests,
            tokenization_kwargs=tokenization_kwargs,
        )

        return self._run_engine(use_tqdm=use_tqdm)

    def _validate_and_add_requests(
        self,
        prompts: Sequence[DictPrompt | TokPrompt],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any] | None = None,
        priority: list[int] | None = None,
    ) -> None:
        """Normalize per-request arguments and add every prompt to the engine.

        ``params``, ``lora_request`` and ``priority`` are expanded to one entry
        per prompt. ``SamplingParams`` are switched to ``FINAL_ONLY`` output
        (in place). If adding any request fails, the requests already added by
        this call are aborted and the exception is re-raised.

        Raises:
            ValueError: If a per-request sequence has the wrong length.
            Exception: Whatever ``_add_request`` raises, after rollback.
        """
        num_requests = len(prompts)
        seq_params = self._params_to_seq(params, num_requests)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
        seq_priority = self._priority_to_seq(priority, num_requests)

        for sp in seq_params:
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        added_request_ids: list[str] = []

        try:
            for i, prompt in enumerate(it):
                request_id = self._add_request(
                    prompt,
                    seq_params[i],
                    lora_request=seq_lora_requests[i],
                    tokenization_kwargs=tokenization_kwargs,
                    priority=seq_priority[i],
                )
                added_request_ids.append(request_id)
        except Exception:
            # Abort the requests already added so a failed batch is not left
            # partially submitted, then propagate the original error.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            raise
1845

1846
    def _add_request(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        priority: int = 0,
    ) -> str:
        """Process one prompt into an engine request and add it to the engine.

        A deprecated ``params.truncate_prompt_tokens`` is merged into
        ``tokenization_kwargs``, and a ``DeprecationWarning`` is emitted.

        Args:
            prompt: The prompt to submit.
            params: Sampling or pooling parameters for this request.
            lora_request: Optional LoRA request.
            tokenization_kwargs: Tokenization overrides.
            priority: Request priority.

        Returns:
            The ``request_id`` of the processed engine request.
        """
        prompt_text, _, _ = extract_prompt_components(self.model_config, prompt)
        request_id = str(next(self.request_counter))

        if params.truncate_prompt_tokens is not None:
            params_type = type(params).__name__
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{params_type}` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=params.truncate_prompt_tokens),
            )

        tok_params = self._get_cmpl_tok_params(tokenization_kwargs)

        tokenization_kwargs = tok_params.get_encode_kwargs()
        engine_request = self.input_processor.process_inputs(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
            supported_tasks=self.supported_tasks,
        )

        self.llm_engine.add_request(
            request_id,
            engine_request,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
            prompt_text=prompt_text,
        )
        return engine_request.request_id
1895

1896
    def _run_engine(
        self, *, use_tqdm: bool | Callable[..., tqdm] = True
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until no requests remain and collect finished outputs.

        Args:
            use_tqdm: If truthy, shows a progress bar (a callable is used as
                the tqdm factory). For ``RequestOutput`` results, the bar's
                postfix reports the estimated input/output tokens per second.

        Returns:
            All finished outputs, sorted by integer request id.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs
                            )
                            # With a coarse clock, elapsed can still be 0.0
                            # for very fast first results; skip the speed
                            # update then instead of dividing by zero.
                            elapsed = pbar.format_dict["elapsed"]
                            if elapsed > 0:
                                in_spd = total_in_toks / elapsed
                                out_spd = total_out_toks / elapsed
                                pbar.postfix = (
                                    f"est. speed input: {in_spd:.2f} toks/s, "
                                    f"output: {out_spd:.2f} toks/s"
                                )
                            pbar.update(n)
                        else:
                            pbar.update(1)
                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
1946

1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
    def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest | dict
    ) -> None:
        """
        Initialize weight transfer (e.g. for RL training) on the workers.

        Args:
            request: A ``WeightTransferInitRequest`` or an equivalent dict.
                Only its ``init_info`` (attribute, or ``"init_info"`` key) is
                forwarded via ``collective_rpc``.
        """
        if isinstance(request, dict):
            init_info = request["init_info"]
        else:
            init_info = request.init_info

        rpc_kwargs = {"init_info": init_info}
        self.llm_engine.collective_rpc(
            "init_weight_transfer_engine", kwargs=rpc_kwargs
        )

    def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
        """
        Push updated model weights to the workers.

        Args:
            request: A ``WeightTransferUpdateRequest`` or an equivalent dict.
                Only its ``update_info`` (attribute, or ``"update_info"`` key)
                is forwarded via ``collective_rpc``.
        """
        if isinstance(request, dict):
            update_info = request["update_info"]
        else:
            update_info = request.update_info

        rpc_kwargs = {"update_info": update_info}
        self.llm_engine.collective_rpc("update_weights", kwargs=rpc_kwargs)

1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model."""
        # Cache the result to avoid repeated collective_rpc calls
        if self._cached_repr is None:
            results = self.llm_engine.collective_rpc("get_model_inspection")
            # In distributed settings, we get results from all workers
            # Just return the first one (they should all be the same)
            if results:
                self._cached_repr = results[0]
            else:
                self._cached_repr = f"LLM(model={self.model_config.model!r})"
        return self._cached_repr