llm.py 80.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Callable, Sequence
7
from typing import TYPE_CHECKING, Any, TypeAlias, cast
8

9
import cloudpickle
10
import torch.nn as nn
11
from pydantic import ValidationError
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar
14

15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
22
    AttentionConfig,
23
    CompilationConfig,
24
    PoolerConfig,
25
    ProfilerConfig,
26
27
28
    StructuredOutputsConfig,
    is_init_field,
)
29
from vllm.config.compilation import CompilationMode
30
from vllm.config.model import (
31
32
    ConvertOption,
    HfOverrides,
33
    ModelDType,
34
    RunnerOption,
35
    TokenizerMode,
36
)
37
38
39
40
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
41
from vllm.engine.arg_utils import EngineArgs
42
43
44
45
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
46
from vllm.entrypoints.pooling.score.utils import (
47
    ScoreData,
48
49
50
    ScoreMultiModalParam,
    _cosine_similarity,
    compress_token_type_ids,
51
    compute_maxsim_score,
52
    get_score_prompt,
53
    validate_score_input,
54
)
55
from vllm.entrypoints.utils import log_non_default_args
56
57
from vllm.inputs import (
    DataPrompt,
58
59
    EmbedsPrompt,
    ExplicitEncoderDecoderPrompt,
60
61
62
63
64
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
65
from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt
66
from vllm.logger import init_logger
67
from vllm.lora.request import LoRARequest
68
from vllm.model_executor.layers.quantization import QuantizationMethods
69
70
71
72
73
74
75
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
76
from vllm.platforms import current_platform
77
from vllm.pooling_params import PoolingParams
78
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
79
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
80
from vllm.tasks import PoolingTask
81
82
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
yhu422's avatar
yhu422 committed
83
from vllm.usage.usage_lib import UsageContext
84
from vllm.utils.collection_utils import as_iter, is_list_of
85
from vllm.utils.counter import Counter
86
from vllm.v1.engine.llm_engine import LLMEngine
87
from vllm.v1.sample.logits_processor import LogitsProcessor
88

89
90
91
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

92
93
# Module-level logger, named after this module.
logger = init_logger(__name__)

94
95
# Generic result type used by helpers such as `collective_rpc`,
# `apply_model` and `_make_config`; falls back to `Any` via the
# TypeVar default (PEP 696, provided here by typing_extensions).
_R = TypeVar("_R", default=Any)

96
97
98
# Union of the singleton prompt types (text, token-ID, or embeds prompt).
EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
# Explicit encoder/decoder prompt whose encoder and decoder parts are each
# an `EnginePrompt`.
EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt]

99
100

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
101
102
103
104
105
106
107
108
109
110
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
111
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
112
113
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
114
115
116
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
117
118
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
119
120
121
122
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
123
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
125
126
127
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
128
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
129
130
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
131
        quantization: The method used to quantize the model weights. Currently,
132
            we support "awq", "gptq", and "fp8" (experimental).
133
134
135
136
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
137
138
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
139
140
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
141
142
143
144
145
146
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
147
148
149
150
151
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vLLM can automatically infer the KV cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the KV cache memory size. kv_cache_memory_bytes
            allows finer-grained control of how much memory gets used than
            gpu_memory_utilization does. Note that when kv_cache_memory_bytes
            is not None, it ignores gpu_memory_utilization.
155
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
161
162
163
164
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
165
166
167
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
168
        enable_return_routed_experts: Whether to return routed experts.
169
170
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
171
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `hf auth login` (stored in `~/.cache/huggingface/token`).
174
175
176
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
177
178
179
180
181
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
182
183
        pooler_config: Initialize non-default pooling config for the pooling model,
            e.g., `PoolerConfig(seq_pooling_type="MEAN", use_activation=False)`.
184
        compilation_config: An integer, a dictionary, or a
            `CompilationConfig` instance. If it is an integer, it is used as
            the mode of compilation optimization. If it is a dictionary, it
            can specify the full compilation configuration.
189
190
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
191
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
192

193
194
    Note:
        This class is intended to be used for offline inference. For online
195
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
196
    """
197
198
199
200

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes the config-like arguments, builds an `EngineArgs` from the
        explicit parameters plus `**kwargs`, and creates the underlying
        `LLMEngine`. See the class docstring for the meaning of each argument.

        Raises:
            ValueError: If `kv_transfer_config` is a dict that fails
                `KVTransferConfig` validation, or if `data_parallel_size > 1`
                is requested while not using the `external_launcher`
                distributed executor backend and not running on TPU.
        """
        # Stats logging is disabled unless the caller explicitly set it.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # If worker_cls is a class rather than a qualified string name,
            # serialize it with cloudpickle to avoid pickling issues.
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if isinstance(kwargs.get("kv_transfer_config"), dict):
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                # Surface the validation failure as a ValueError, keeping the
                # original pydantic error as the cause.
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance.

            None yields a default-constructed `cls`; dict keys that are not
            init fields of `cls` are dropped; anything else is returned as-is.
            """
            if value is None:
                return cls()
            if isinstance(value, dict):
                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
            return value

        if isinstance(compilation_config, int):
            # An integer selects the compilation mode.
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data-parallel usage (it may hang); the
        # external_launcher backend and TPU are exempt from this check.
        dp_size = int(kwargs.get("data_parallel_size", 1))
        executor_backend = kwargs.get("distributed_executor_backend")
        if (
            dp_size > 1
            and executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Populated lazily by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        self.model_config = self.llm_engine.model_config
        self.input_processor = self.llm_engine.input_processor
        self.io_processor = self.llm_engine.io_processor

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

364
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer exposed by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
366

367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
    def get_world_size(self, include_dp: bool = True) -> int:
        """Get the world size from the parallel config.

        Args:
            include_dp: If True (default), returns the world size including
                data parallelism (TP * PP * DP). If False, returns the world
                size without data parallelism (TP * PP).

        Returns:
            The world size (tensor_parallel_size * pipeline_parallel_size),
            optionally multiplied by data_parallel_size if include_dp is True.
        """
        parallel_config = self.llm_engine.vllm_config.parallel_config
        if include_dp:
            return parallel_config.world_size_across_dp
        return parallel_config.world_size

384
    def reset_mm_cache(self) -> None:
385
        self.input_processor.clear_mm_cache()
386
387
        self.llm_engine.reset_mm_cache()

388
    def get_default_sampling_params(self) -> SamplingParams:
        """Build the default `SamplingParams` for this model.

        The result of `model_config.get_diff_sampling_param()` is fetched on
        first use and cached in `self.default_sampling_params`. A non-empty
        diff is applied through `SamplingParams.from_optional`; an empty one
        yields a plain `SamplingParams()`.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = self.model_config.get_diff_sampling_param()

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

395
396
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """Generate completions for the given prompts.

        Prompts are batched automatically with memory constraints in mind, so
        for the best performance pass all prompts as a single list in one call.

        Args:
            prompts: The prompts to the LLM. A sequence of prompts may be
                passed for batch inference. See
                [PromptType][vllm.inputs.PromptType] for the accepted formats.
            sampling_params: Sampling parameters for text generation. When
                None, the default sampling parameters are used. A single value
                applies to every prompt; a list must match `prompts` in length
                and is paired with the prompts element-wise.
            use_tqdm: `True` shows a tqdm progress bar; a callable (e.g.
                `functools.partial(tqdm, leave=False)`) is used to create the
                progress bar; `False` disables it.
            lora_request: LoRA request(s) to use for generation, if any.
            priority: Per-request priorities, if any. Only applicable when the
                priority scheduling policy is enabled. When provided, it must
                be a list of integers with the same length as `prompts`, where
                each value applies to the prompt at the same index.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `RequestOutput` objects holding the generated
            completions, in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not "generate".
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        params = (
            self.get_default_sampling_params()
            if sampling_params is None
            else sampling_params
        )
        # Resolve modality-specific default LoRAs before queueing requests.
        resolved_loras = self._get_modality_specific_lora_reqs(prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=params,
            use_tqdm=use_tqdm,
            lora_request=resolved_loras,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )

        finished = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(finished, RequestOutput)
459

460
    def _get_modality_specific_lora_reqs(
        self,
        prompts: PromptType | Sequence[PromptType],
        lora_request: list[LoRARequest] | LoRARequest | None,
    ) -> list[LoRARequest | None] | LoRARequest | None:
        """Resolve default modality-specific LoRAs for the given prompts.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: Explicit LoRA request(s): None, one request applied
                to every prompt, or a sequence with one entry per prompt.

        Returns:
            `lora_request` unchanged when there is no LoRA config, the model
            is not multi-modal, or `default_mm_loras` is None. Otherwise, a
            per-prompt list produced by `_resolve_single_prompt_mm_lora`.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from the number of prompts (previously `zip` would silently
                truncate or misalign the pairing).
        """
        # The LoRA config is read from the engine's vllm config.
        lora_config = self.llm_engine.vllm_config.lora_config

        # Without a LoRA config / default_mm_loras, or for a non-multimodal
        # model, leave the LoRA request as is.
        if (
            lora_config is None
            or not self.model_config.is_multimodal_model
            or lora_config.default_mm_loras is None
        ):
            return lora_request

        # Normalize a single prompt (a string is a Sequence, but one prompt).
        if not isinstance(prompts, Sequence) or isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(prompts)

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            )
            for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

496
497
498
    def _resolve_single_prompt_mm_lora(
        self,
        prompt: PromptType,
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
    ):
        """Choose the LoRA request for a single prompt.

        A modality default LoRA is built only when the prompt is a dict whose
        non-empty `multi_modal_data` shares exactly one key with
        `default_mm_loras` and `lora_request` is falsy.

        Args:
            prompt: The prompt to inspect.
            lora_request: The LoRA request explicitly provided for this
                prompt, if any.
            default_mm_loras: Mapping from modality name to LoRA path.

        Returns:
            `lora_request` unchanged, or a new `LoRARequest` for the single
            matching modality.
        """
        if not default_mm_loras or not isinstance(prompt, dict):
            return lora_request

        mm_data = prompt.get("multi_modal_data") or {}
        if not mm_data:
            return lora_request

        matched_modalities = set(
            mm_data.keys()  # type: ignore
        ).intersection(default_mm_loras.keys())

        if not matched_modalities:
            return lora_request

        if len(matched_modalities) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped",
                matched_modalities,
            )
            return lora_request

        # Exactly one modality matched. Its LoRA ID is its 1-based position
        # among the registered modality names sorted alphabetically.
        modality_name = matched_modalities.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; warn when its ID
        # differs from the modality default's ID.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt"
                )
            return lora_request

        return LoRARequest(modality_name, modality_lora_id, modality_lora_path)

549
550
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Run an RPC on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and sent to every worker. A callable must accept an
                additional `self` argument (the worker object) besides the
                values passed through `args` and `kwargs`.
            timeout: Maximum number of seconds to wait; a
                [`TimeoutError`][] is raised when it elapses. `None` waits
                indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list with the result from each worker.

        Note:
            Prefer this API for control messages only, and set up separate
            data-plane communication for passing data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
580
581

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Call `func` on the model held by each worker and collect the results.

        !!! warning
            Returning large arrays or tensors from `func` adds data-transfer
            overhead. If you must return them, move them to CPU first so they
            do not occupy additional VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
593

594
595
    def _get_beam_search_lora_requests(
        self,
        lora_request: list[LoRARequest] | LoRARequest | None,
        prompts: list[TokensPrompt | TextPrompt],
    ) -> list[LoRARequest | None]:
        """Get the optional LoRA request corresponding to each prompt.

        Args:
            lora_request: None or a single `LoRARequest` (applied to every
                prompt), or a sequence with exactly one entry per prompt.
            prompts: The beam-search prompts.

        Returns:
            A list with one (possibly None) LoRA request per prompt.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from `len(prompts)`.
            TypeError: If `lora_request` is of any other type (including
                `str` / `bytes`).
        """
        if isinstance(lora_request, Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts"
            )

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        # A per-prompt sequence (length already validated above) is allowed by
        # the signature; use it as-is instead of rejecting it. Strings/bytes
        # are Sequences too, but never a valid list of requests.
        if isinstance(lora_request, Sequence) and not isinstance(
            lora_request, (str, bytes)
        ):
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

610
611
    def beam_search(
        self,
612
        prompts: list[TokensPrompt | TextPrompt],
613
        params: BeamSearchParams,
614
        lora_request: list[LoRARequest] | LoRARequest | None = None,
615
        use_tqdm: bool = False,
616
        concurrency_limit: int | None = None,
617
    ) -> list[BeamSearchOutput]:
618
619
620
621
622
623
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
624
            params: The beam search parameters.
625
            lora_request: LoRA request to use for generation, if any.
626
            use_tqdm: Whether to use tqdm to display the progress bar.
627
628
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
629
        """
630
631
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
632
633
634
635
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
636
637
        length_penalty = params.length_penalty

638
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
639

640
641
642
643
644
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
645

646
647
648
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
649
650
                "Disabling progress bar."
            )
651
652
653
654
655
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

656
657
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
658
659
660
661
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
662
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
663
            return TokensPrompt(**token_prompt_kwargs)
664

665
666
667
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
668
        beam_search_params = SamplingParams(
669
670
671
672
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
673
        )
674
        instances: list[BeamSearchInstance] = []
675

676
        for lora_req, prompt in zip(lora_requests, prompts):
677
678
679
680
681
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
682
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
683

684
685
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
686
687
688
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
689

690
            instances.append(
691
692
693
694
695
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
696
697
                ),
            )
698

699
        for prompt_start in range(0, len(prompts), concurrency_limit):
700
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
701
702
703

            token_iter = range(max_tokens)
            if use_tqdm:
704
705
706
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
707
708
709
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
710
711
                    "reflect instance-level progress."
                )
712
713
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
714
715
                    sum((instance.beams for instance in instances_batch), [])
                )
716
717
                pos = [0] + list(
                    itertools.accumulate(
718
719
720
                        len(instance.beams) for instance in instances_batch
                    )
                )
721
                instance_start_and_end: list[tuple[int, int]] = list(
722
723
                    zip(pos[:-1], pos[1:])
                )
724
725
726
727
728
729

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
730
731
732
733
734
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
735
736
737

                # only runs for one step
                # we don't need to use tqdm here
738
739
740
741
742
743
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
744

745
746
747
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
748
749
750
751
752
753
754
755
756
757
758
759
760
761
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
762
                                    logprobs=current_beam.logprobs + [logprobs],
763
                                    lora_request=current_beam.lora_request,
764
765
766
767
768
769
770
771
772
773
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
774
775
776
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
777
778
779
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
780
                    instance.beams = sorted_beams[:beam_width]
781
782
783
784

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
785
786
787
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
788
789
790
791
792
793
794
795
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

796
797
798
799
800
801
802
803
804
805
806
807
808
809
    def _get_cmpl_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
        """Build the tokenization parameters used for completion-style prompts.

        Args:
            tokenization_kwargs: Optional overrides layered on top of the
                model-derived defaults via `TokenizeParams.with_kwargs`.

        Returns:
            The defaults (bounded by the model's `max_model_len`) with
            `tokenization_kwargs` applied.
        """
        cfg = self.model_config
        lowercase_input = (cfg.encoder_config or {}).get("do_lower_case", False)

        # For Whisper, special tokens should be provided by the user based
        # on the task and language of their request. Also needed to avoid
        # appending an EOS token to the prompt which disrupts generation.
        add_special = not cfg.is_encoder_decoder

        defaults = TokenizeParams(
            max_total_tokens=cfg.max_model_len,
            do_lower_case=lowercase_input,
            add_special_tokens=add_special,
        )
        return defaults.with_kwargs(tokenization_kwargs)

    def _normalize_prompts(
        self,
        prompts: PromptType | Sequence[PromptType],
    ) -> list[EnginePrompt | EngineEncDecPrompt]:
        """Coerce `prompts` into a sequence of individual prompts.

        A bare string is wrapped into a `TextPrompt` first. A `str` is itself
        a `Sequence`, so it would otherwise be treated as a batch of
        characters. Any other non-sequence (e.g. a single prompt dict) is
        wrapped in a one-element list. Sequences are returned unchanged.
        """
        if isinstance(prompts, str):
            return [TextPrompt(prompt=prompts)]  # type: ignore[list-item]
        if isinstance(prompts, Sequence):
            return prompts  # type: ignore[return-value]
        return [prompts]  # type: ignore[list-item]

    def _preprocess_cmpl_singleton(
        self,
        prompt: SingletonPrompt,
        tok_params: TokenizeParams,
        *,
        tokenize: bool,
    ) -> EnginePrompt:
        """Turn one non-encoder/decoder prompt into an engine prompt.

        Prompts that are not already dicts (e.g. plain strings) are first
        passed through the renderer's `render_completion`. When `tokenize`
        is true, the result is then tokenized with `tok_params`. Otherwise
        the (possibly rendered) prompt is returned unchanged.
        """
        renderer = self.llm_engine.renderer
        rendered = (
            prompt if isinstance(prompt, dict) else renderer.render_completion(prompt)
        )
        if not tokenize:
            return rendered
        return renderer.tokenize_prompt(rendered, tok_params)

    def _preprocess_cmpl_enc_dec(
        self,
        prompt: ExplicitEncoderDecoderPrompt,
        tok_params: TokenizeParams,
    ) -> EngineEncDecPrompt:
        """Preprocess both halves of an explicit encoder/decoder prompt.

        The encoder prompt is always preprocessed. The decoder prompt is
        preprocessed only when it is not `None`; otherwise `None` is kept.
        Both halves share `tok_params` and the same tokenize decision.
        """
        # TODO: Move multi-modal processor into tokenization
        should_tokenize = not self.model_config.is_multimodal_model

        def _process(part: SingletonPrompt) -> EnginePrompt:
            return self._preprocess_cmpl_singleton(
                part, tok_params, tokenize=should_tokenize
            )

        encoder_part = prompt["encoder_prompt"]
        decoder_part = prompt["decoder_prompt"]

        return EngineEncDecPrompt(
            encoder_prompt=_process(encoder_part),
            decoder_prompt=None if decoder_part is None else _process(decoder_part),
        )

    def _preprocess_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EnginePrompt | EngineEncDecPrompt]:
        """
        Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
        a format that can be passed to `_add_request`.

        Refer to [LLM.generate][] for a complete description of the arguments.

        Returns:
            One engine prompt per input prompt, in input order. Explicit
            encoder/decoder prompts become an `EngineEncDecPrompt`. All other
            prompts are rendered and, unless the model is multi-modal, also
            tokenized.
        """
        tok_params = self._get_cmpl_tok_params(tokenization_kwargs)

        # Some MM models have non-default `add_special_tokens`
        # TODO: Move multi-modal processor into tokenization
        do_tokenize = not self.model_config.is_multimodal_model

        return [
            self._preprocess_cmpl_enc_dec(prompt, tok_params)
            if is_explicit_encoder_decoder_prompt(prompt)
            else self._preprocess_cmpl_singleton(
                prompt, tok_params, tokenize=do_tokenize
            )
            for prompt in self._normalize_prompts(prompts)
        ]

    def _normalize_conversations(
        self,
        conversations: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
    ) -> list[list[ChatCompletionMessageParam]]:
        """Return `conversations` as a list of conversations.

        If `is_list_of(conversations, list)` holds, the input is treated as a
        batch and returned unchanged. Otherwise it is a single conversation
        and is wrapped in a one-element list.
        """
        if not is_list_of(conversations, list):
            return [conversations]  # type: ignore[list-item]
        return conversations  # type: ignore[return-value]

    def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
        """Build the tokenization parameters used for rendered chat prompts.

        The tokenizer is told not to add special tokens. Presumably this is
        because the chat template output already contains them; confirm
        against the templates. `tokenization_kwargs`, if given, are layered
        on top via `TokenizeParams.with_kwargs`.
        """
        cfg = self.model_config
        lowercase_input = (cfg.encoder_config or {}).get("do_lower_case", False)
        defaults = TokenizeParams(
            max_total_tokens=cfg.max_model_len,
            do_lower_case=lowercase_input,
            add_special_tokens=False,
        )
        return defaults.with_kwargs(tokenization_kwargs)

    def _preprocess_chat(
        self,
        conversations: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[EnginePrompt]:
        """
        Convert a list of conversations into prompts so that they can then
        be used as input for other LLM APIs.

        Refer to [LLM.chat][] for a complete description of the arguments.

        Returns:
            One tokenized engine prompt per conversation, in input order.
            Each conversation is rendered with `renderer.render_messages`.
            If `mm_processor_kwargs` is given, it is attached to the rendered
            prompt. The result is then tokenized with the chat tokenization
            parameters.
        """
        renderer = self.llm_engine.renderer

        # Caller-supplied template kwargs are merged with the flags this API
        # controls. `tokenize` is set only when the tokenizer is a
        # `MistralTokenizer`.
        template_kwargs = merge_kwargs(
            chat_template_kwargs,
            dict(
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
            ),
        )
        chat_params = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=template_kwargs,
        )
        tok_params = self._get_chat_tok_params(tokenization_kwargs)

        def _render_and_tokenize(
            conversation: list[ChatCompletionMessageParam],
        ) -> EnginePrompt:
            _, rendered = renderer.render_messages(conversation, chat_params)
            if mm_processor_kwargs is not None:
                rendered["mm_processor_kwargs"] = mm_processor_kwargs
            return renderer.tokenize_prompt(rendered, tok_params)

        return [
            _render_and_tokenize(conversation)
            for conversation in self._normalize_conversations(conversations)
        ]
959
960
961

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered through the chat template and
        tokenized. The resulting prompts are then passed to the
        [generate][vllm.LLM.generate] method to produce the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, the chat template is asked to
                append its generation prompt after the final message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions, forwarded to the chat template
                as the `tools` template kwarg.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            mm_processor_kwargs: Overrides for `processor.__call__`.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        engine_prompts = self._preprocess_chat(
            messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            mm_processor_kwargs=mm_processor_kwargs,
            tokenization_kwargs=tokenization_kwargs,
        )

        return self.generate(
            engine_prompts,
            sampling_params=sampling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
            tokenization_kwargs=tokenization_kwargs,
        )

1046
1047
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict
                containing a `"data"` key is routed through the installed
                IOProcessor plugin instead.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Caller-owned objects are
                not modified when the task is filled in.
            truncate_prompt_tokens: Deprecated. Pass it via
                `tokenization_kwargs` instead.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run (required). Pooling params
                without a task are bound to it.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If `pooling_task` is missing, the model is not a
                pooling model, a `"data"` prompt is given without an
                IOProcessor plugin, or a pooling param already carries a
                different task.
        """
        import copy

        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if truncate_prompt_tokens is not None:
            warnings.warn(
                "The `truncate_prompt_tokens` parameter in `LLM.encode()` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        if io_processor_prompt:
            assert self.io_processor is not None
            if is_list_of(pooling_params, PoolingParams):
                validated_pooling_params: list[PoolingParams] = []
                for param in as_iter(pooling_params):
                    validated_pooling_params.append(
                        self.io_processor.validate_or_generate_params(param)
                    )
                pooling_params = validated_pooling_params
            else:
                assert not isinstance(pooling_params, Sequence)
                pooling_params = self.io_processor.validate_or_generate_params(
                    pooling_params
                )

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        def _bind_task(param: PoolingParams) -> PoolingParams:
            # Copy before writing the task. Otherwise the caller's object is
            # mutated, and reusing it for a different task (e.g. `embed()`
            # then `classify()` with the same instance) fails the check below.
            if param.task is None:
                param = copy.deepcopy(param)
                param.task = pooling_task
            elif param.task != pooling_task:
                msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                raise ValueError(msg)
            return param

        if isinstance(pooling_params, PoolingParams):
            pooling_params = _bind_task(pooling_params)
        else:
            pooling_params = [_bind_task(param) for param in pooling_params]

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if io_processor_prompt:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(
                model_output=model_outputs
            )

            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
                    prompt_token_ids=[],
                    finished=True,
                )
            ]
        else:
            return model_outputs
1200

1201
1202
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If set, merged into `tokenization_kwargs`
                as `truncate_prompt_tokens`.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        # Folded into the tokenization overrides here instead of being passed
        # to `encode()`, whose own handling of the argument emits a
        # DeprecationWarning.
        if truncate_prompt_tokens is not None:
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        pooled = self.encode(
            prompts,
            pooling_task="embed",
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
            tokenization_kwargs=tokenization_kwargs,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        pooled = self.encode(
            prompts,
            pooling_task="classify",
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
            tokenization_kwargs=tokenization_kwargs,
        )
        return list(map(ClassificationRequestOutput.from_base, pooled))

1308
1309
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: If set, merged into `tokenization_kwargs`
                as `truncate_prompt_tokens`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        # Fold `truncate_prompt_tokens` into the tokenization overrides, the
        # same way `embed()` does. Forwarding it to `encode()` would emit
        # encode's "parameter in `LLM.encode()` is deprecated" warning with
        # `stacklevel=2`, which points at this method rather than at the
        # user's call.
        if truncate_prompt_tokens is not None:
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_classify",
            tokenization_kwargs=tokenization_kwargs,
        )

1350
1351
    def _embedding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """Score `data_1`/`data_2` pairs by cosine similarity of embeddings.

        Both sides are embedded in a single `encode` call using the "embed"
        pooling task. When `data_1` holds exactly one item, that embedding is
        paired with every entry of `data_2`.

        Raises:
            NotImplementedError: If any input is not a plain string.
        """
        tokenizer = self.get_tokenizer()

        # Only text is supported here; reject anything else up front.
        combined = data_1 + data_2
        if any(not isinstance(item, str) for item in combined):
            raise NotImplementedError(
                "Embedding scores currently do not support multimodal input."
            )
        input_texts = cast(list[str], combined)

        embeddings = self.encode(
            input_texts,
            pooling_task="embed",
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            use_tqdm=use_tqdm,
        )

        num_queries = len(data_1)
        query_embeds = embeddings[:num_queries]
        doc_embeds = embeddings[num_queries:]
        if len(query_embeds) == 1:
            # 1 -> N: reuse the single query embedding for every document.
            query_embeds = query_embeds * len(doc_embeds)

        similarities = _cosine_similarity(
            tokenizer=tokenizer,
            embed_1=query_embeds,
            embed_2=doc_embeds,
        )

        validated = self.engine_class.validate_outputs(
            similarities, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(out) for out in validated]

1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
    def _late_interaction_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Late interaction scoring (ColBERT MaxSim).

        Queries (`data_1`) and documents (`data_2`) are encoded together in a
        single `encode` call with the "token_embed" pooling task, producing
        per-token embeddings. Each query/document pair is then scored with
        `compute_maxsim_score`. A single query is paired with every document.

        Raises:
            NotImplementedError: If any input is not a plain string.
        """
        from vllm.outputs import PoolingOutput

        tokenizer = self.get_tokenizer()

        # Only text inputs are supported; queries are checked before documents.
        for item in itertools.chain(data_1, data_2):
            if not isinstance(item, str):
                raise NotImplementedError(
                    "Late interaction scores currently do not support multimodal input."
                )
        queries = cast(list[str], data_1)
        documents = cast(list[str], data_2)

        encoded: list[PoolingRequestOutput] = self.encode(
            queries + documents,
            pooling_task="token_embed",
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            use_tqdm=use_tqdm,
        )

        num_queries = len(queries)
        query_outputs: list[PoolingRequestOutput] = encoded[:num_queries]
        doc_outputs: list[PoolingRequestOutput] = encoded[num_queries:]
        if len(query_outputs) == 1:
            # 1 -> N: reuse the single query encoding for every document.
            query_outputs = query_outputs * len(doc_outputs)

        # The pad token (when the tokenizer has one) is placed between the query
        # and document token ids in each pair's reported prompt_token_ids.
        pad_token_id = tokenizer.pad_token_id
        separator: list[int] = [] if pad_token_id is None else [pad_token_id]

        def _pair_output(
            q_out: PoolingRequestOutput, d_out: PoolingRequestOutput
        ) -> PoolingRequestOutput:
            # q_out.outputs.data: [query_len, dim]
            # d_out.outputs.data: [doc_len, dim]
            return PoolingRequestOutput(
                request_id=f"{q_out.request_id}_{d_out.request_id}",
                outputs=PoolingOutput(
                    data=compute_maxsim_score(q_out.outputs.data, d_out.outputs.data)
                ),
                prompt_token_ids=(
                    q_out.prompt_token_ids + separator + d_out.prompt_token_ids
                ),
                num_cached_tokens=q_out.num_cached_tokens + d_out.num_cached_tokens,
                finished=True,
            )

        pair_outputs = [
            _pair_output(q_out, d_out)
            for q_out, d_out in zip(query_outputs, doc_outputs)
        ]

        validated = self.engine_class.validate_outputs(
            pair_outputs, PoolingRequestOutput
        )
        return [ScoringRequestOutput.from_base(out) for out in validated]

1475
1476
    def _cross_encoding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
        score_template: str | None,
    ) -> list[ScoringRequestOutput]:
        """Score each `(data_1[i], data_2[i])` pair with a cross-encoder.

        Every pair is rendered into a single prompt via `get_score_prompt`,
        submitted to the engine, and run to completion. When `data_1` holds
        exactly one item it is paired with every entry of `data_2`.

        Args:
            data_1: Query side of each pair (or a single shared query).
            data_2: Document side of each pair.
            use_tqdm: Progress-bar setting forwarded to request submission and
                engine execution.
            pooling_params: Pooling parameters. If None, a fresh
                `PoolingParams(task="score")` is used. If its `task` is None,
                a clone with `task="score"` is used; the caller's object is
                never modified.
            lora_request: LoRA request(s) forwarded to request submission.
            tokenization_kwargs: Forwarded to `get_score_prompt`.
            score_template: Template forwarded to `get_score_prompt`.

        Returns:
            One `ScoringRequestOutput` per pair, in input order.

        Raises:
            ValueError: If the tokenizer is a `MistralTokenizer`.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            # 1 -> N: broadcast the single query over every document.
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        elif pooling_params.task is None:
            # Work on a copy so the caller's PoolingParams is not mutated;
            # they may reuse it for other (non-score) requests.
            pooling_params = pooling_params.clone()
            pooling_params.task = "score"

        pooling_params_list = list[PoolingParams]()
        prompts = list[PromptType]()

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            if token_type_ids := engine_prompt.pop("token_type_ids", None):
                # Token type ids are not sent as part of the prompt; they are
                # compressed into per-request pooling extra_kwargs instead.
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1538
1539
    def score(
        self,
        data_1: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        data_2: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.

        How a pair is scored depends on the model:

        - cross-encoder models receive both sides of a pair together in one
          prompt;
        - late-interaction models (e.g. ColBERT) encode each side into
          per-token embeddings and are scored with MaxSim;
        - all other models embed each side separately and are scored with
          cosine similarity.

        For the best performance, put all of your inputs into a single list
        and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. A single item is paired with every entry of
                `data_2`; otherwise it must have the same length as `data_2`.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: The chat template to use for the scoring. If None, we
                use the model's default chat template. Only accepted for
                cross-encoder models.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, supports none of
                the scoring paths, is a cross-encoder with `num_labels != 1`,
                or `chat_template` is given for a non-cross-encoder model.
        """
        cfg = self.model_config

        if cfg.runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        tasks = self.supported_tasks
        # Late interaction models (e.g., ColBERT) use token_embed for scoring,
        # so they need neither the "embed" nor the "classify" task.
        late_interaction = cfg.is_late_interaction
        if not (late_interaction or any(t in tasks for t in ("embed", "classify"))):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        cross_encoder = cfg.is_cross_encoder
        if cross_encoder:
            if getattr(cfg.hf_config, "num_labels", 0) != 1:
                raise ValueError("Score API is only enabled for num_labels == 1.")
        elif chat_template is not None:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        score_data_1, score_data_2 = validate_score_input(
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            is_multimodal_model=cfg.is_multimodal_model,
            architecture=cfg.architecture,
        )

        tok_params = self._get_cmpl_tok_params(tokenization_kwargs)
        encode_kwargs = tok_params.get_encode_kwargs()

        if cross_encoder:
            return self._cross_encoding_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
                score_template=chat_template,
            )

        if late_interaction:
            return self._late_interaction_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
            )

        return self._embedding_score(
            score_data_1,
            score_data_2,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
        )
1668

1669
1670
1671
1672
1673
1674
    def start_profile(self) -> None:
        """Ask the underlying engine to start profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1675
1676
1677
1678
1679
1680
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache.

        Args:
            reset_running_requests: Forwarded positionally to
                `llm_engine.reset_prefix_cache`.
            reset_connector: Forwarded positionally to
                `llm_engine.reset_prefix_cache`.

        Returns:
            Whatever boolean the engine's `reset_prefix_cache` reports.
        """
        engine = self.llm_engine
        succeeded = engine.reset_prefix_cache(reset_running_requests, reset_connector)
        return succeeded
1681

1682
1683
1684
1685
1686
1687
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset before the engine is put to sleep.

        Args:
            level: The sleep level.

                - Level 1 offloads the model weights and discards the kv
                  cache, whose content is forgotten. The weights are backed
                  up in CPU memory, so make sure there is enough CPU memory to
                  hold them. Good for sleeping and waking up the engine to run
                  the same model again.
                - Level 2 discards both the model weights and the kv cache;
                  the content of both is forgotten. Good for sleeping and
                  waking up the engine to run a different model or update the
                  model, where the previous weights are not needed. It reduces
                  CPU memory pressure.
        """
        # Clear cached prefixes first, then hand control to the engine.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1704
    def wake_up(self, tags: list[str] | None = None):
1705
        """
1706
1707
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1708

1709
        Args:
1710
1711
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1712
                `("weights", "kv_cache")`. If None, all memory is reallocated.
1713
                wake_up should be called with all tags (or None) before the
1714
1715
1716
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1717

1718
1719
1720
1721
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        The call is delegated directly to `llm_engine.get_metrics()`.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1730
1731
    def _validate_and_add_requests(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | Sequence[SamplingParams]
        | PoolingParams
        | Sequence[PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any] | None = None,
        priority: list[int] | None = None,
    ) -> None:
        """Validate per-request arguments and add every prompt to the engine.

        `params`, `lora_request` and `priority` may each be given once (shared
        by every prompt) or as a sequence with exactly one entry per prompt.

        Args:
            prompts: A single prompt or a sequence of prompts.
            params: Sampling/pooling parameters, shared or one per prompt.
                `SamplingParams` objects are modified in place so that only
                the final output is reported.
            use_tqdm: If truthy, show a progress bar while adding requests; a
                callable is used as the progress-bar factory.
            lora_request: LoRA request(s), shared or one per prompt.
            tokenization_kwargs: Overrides used when preprocessing prompts and
                forwarded to `_add_request`.
            priority: Optional per-request priorities; defaults to 0 for all.

        Raises:
            ValueError: If a per-request sequence's length does not match the
                number of prompts.

        If adding any request fails, the requests already added by this call
        are aborted before the exception is re-raised.
        """
        in_prompts = self._normalize_prompts(prompts)
        num_requests = len(in_prompts)

        if isinstance(params, Sequence):
            if len(params) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and params ({len(params)}) must be the same."
                )

            engine_params = params
        else:
            engine_params = [params] * num_requests

        if isinstance(lora_request, Sequence):
            if len(lora_request) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and lora_request ({len(lora_request)}) must be the same."
                )

            engine_lora_requests: Sequence[LoRARequest | None] = lora_request
        else:
            engine_lora_requests = [lora_request] * num_requests

        if priority is not None:
            if len(priority) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and priority ({len(priority)}) must be the same."
                )
        else:
            priority = [0] * num_requests

        if any(param.truncate_prompt_tokens is not None for param in engine_params):
            # TODO: Remove this after deprecating `param.truncate_prompt_tokens`
            # Then, move the code from the `else` block to the top and let
            # `self._preprocess_completion` handle prompt normalization
            engine_prompts = [
                engine_prompt
                for in_prompt, param in zip(in_prompts, engine_params)
                for engine_prompt in self._preprocess_completion(
                    [in_prompt],
                    tokenization_kwargs=merge_kwargs(
                        tokenization_kwargs,
                        dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
                    ),
                )
            ]
        else:
            engine_prompts = self._preprocess_completion(
                in_prompts,
                tokenization_kwargs=tokenization_kwargs,
            )

        for sp in engine_params:
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = engine_prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        added_request_ids: list[str] = []

        try:
            for i, prompt in enumerate(it):
                request_id = self._add_request(
                    prompt,
                    engine_params[i],
                    lora_request=engine_lora_requests[i],
                    tokenization_kwargs=tokenization_kwargs,
                    priority=priority[i],
                )
                added_request_ids.append(request_id)
        except Exception:
            # Roll back the partially-added batch so no orphaned requests
            # remain in the engine.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            raise
1825

1826
    def _add_request(
        self,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        priority: int = 0,
    ) -> str:
        """Process a single prompt and submit it to the engine.

        A fresh request ID is drawn from `self.request_counter`. If `params`
        still carries the deprecated `truncate_prompt_tokens`, a
        `DeprecationWarning` is emitted and the value is merged into
        `tokenization_kwargs`.

        Returns:
            The `request_id` of the engine request built by the input
            processor.
        """
        prompt_text, _, _ = get_prompt_components(prompt)
        request_id = str(next(self.request_counter))

        legacy_truncation = params.truncate_prompt_tokens
        if legacy_truncation is not None:
            warnings.warn(
                "The `truncate_prompt_tokens` parameter in "
                f"`{type(params).__name__}` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=legacy_truncation),
            )

        # Normalize the user overrides into the kwargs actually used to encode.
        encode_kwargs = self._get_cmpl_tok_params(
            tokenization_kwargs
        ).get_encode_kwargs()

        engine_request = self.input_processor.process_inputs(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
            priority=priority,
            supported_tasks=self.supported_tasks,
        )

        self.llm_engine.add_request(
            request_id,
            engine_request,
            params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
            priority=priority,
            prompt_text=prompt_text,
        )
        return engine_request.request_id
1875

1876
    def _run_engine(
        self, *, use_tqdm: bool | Callable[..., tqdm] = True
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until every pending request has finished.

        Args:
            use_tqdm: If truthy, show a progress bar. A callable is used as the
                progress-bar factory; otherwise `tqdm` is used.

        Returns:
            The finished outputs, sorted by integer request ID so that they
            line up with the order in which the requests were added.
        """
        pbar: tqdm | None = None
        num_requests = 0
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue

                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        n = len(output.outputs)
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids) * n
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs
                        )
                        # A disabled tqdm bar (e.g. created via
                        # `functools.partial(tqdm, disable=True)`) reports an
                        # elapsed time of 0; skip the speed estimate instead of
                        # dividing by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s"
                            )
                        pbar.update(n)
                    else:
                        pbar.update(1)

                    if pbar.n == num_requests:
                        pbar.refresh()
        finally:
            # Close the bar even if stepping the engine raises.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
1926

1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
    def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest | dict
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Args:
            request: Weight transfer initialization request with backend-specific
                info, given either as a `WeightTransferInitRequest` or as a
                plain dict holding an `"init_info"` key.
        """
        if isinstance(request, dict):
            init_info = request["init_info"]
        else:
            init_info = request.init_info

        # Forward the init info to the workers through collective_rpc.
        rpc_kwargs = {"init_info": init_info}
        self.llm_engine.collective_rpc(
            "init_weight_transfer_engine", kwargs=rpc_kwargs
        )

    def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
        """
        Update the weights of the model.

        Args:
            request: Weight update request with backend-specific update info,
                given either as a `WeightTransferUpdateRequest` or as a plain
                dict holding an `"update_info"` key.
        """
        if isinstance(request, dict):
            update_info = request["update_info"]
        else:
            update_info = request.update_info

        # Forward the update info to the workers through collective_rpc.
        rpc_kwargs = {"update_info": update_info}
        self.llm_engine.collective_rpc("update_weights", kwargs=rpc_kwargs)

1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model."""
        # Cache the result to avoid repeated collective_rpc calls
        if self._cached_repr is None:
            results = self.llm_engine.collective_rpc("get_model_inspection")
            # In distributed settings, we get results from all workers
            # Just return the first one (they should all be the same)
            if results:
                self._cached_repr = results[0]
            else:
                self._cached_repr = f"LLM(model={self.model_config.model!r})"
        return self._cached_repr