llm.py 81.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
7
from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Any
8

9
import cloudpickle
10
import torch.nn as nn
11
from pydantic import ValidationError
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar, overload
14

15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
22
    AttentionConfig,
23
    CompilationConfig,
24
    PoolerConfig,
25
    ProfilerConfig,
26
27
28
    StructuredOutputsConfig,
    is_init_field,
)
29
from vllm.config.compilation import CompilationMode
30
from vllm.config.model import (
31
32
    ConvertOption,
    HfOverrides,
33
    ModelDType,
34
    RunnerOption,
35
    TokenizerMode,
36
)
37
38
39
40
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
41
from vllm.engine.arg_utils import EngineArgs
42
43
44
45
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
46
from vllm.entrypoints.pooling.score.utils import (
47
    ScoreData,
48
49
50
    ScoreMultiModalParam,
    _cosine_similarity,
    compress_token_type_ids,
51
    compute_maxsim_score,
52
    get_score_prompt,
53
    score_data_to_prompts,
54
    validate_score_input,
55
)
56
from vllm.entrypoints.utils import log_non_default_args
57
from vllm.inputs.data import (
58
    DataPrompt,
59
    ProcessorInputs,
60
61
62
63
64
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
65
from vllm.logger import init_logger
66
from vllm.lora.request import LoRARequest
67
from vllm.model_executor.layers.quantization import QuantizationMethods
68
69
70
71
72
73
74
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
75
from vllm.platforms import current_platform
76
from vllm.pooling_params import PoolingParams
77
from vllm.renderers import ChatParams, merge_kwargs
78
79
80
81
82
from vllm.renderers.inputs.preprocess import (
    conversation_to_seq,
    parse_model_prompt,
    prompt_to_seq,
)
83
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
84
from vllm.tasks import PoolingTask
85
from vllm.tokenizers import TokenizerLike
yhu422's avatar
yhu422 committed
86
from vllm.usage.usage_lib import UsageContext
87
from vllm.utils.counter import Counter
88
from vllm.utils.mistral import is_mistral_tokenizer
89
from vllm.utils.tqdm_utils import maybe_tqdm
90
from vllm.v1.engine import PauseMode
91
from vllm.v1.engine.llm_engine import LLMEngine
92
from vllm.v1.sample.logits_processor import LogitsProcessor
93

94
95
96
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

97
98
logger = init_logger(__name__)

# Output element type for engine-draining helpers such as
# `wait_for_completion`; unconstrained callers get the full union.
_O = TypeVar(
    "_O",
    default=RequestOutput | PoolingRequestOutput,
    bound=RequestOutput | PoolingRequestOutput,
)
# Per-request parameter type (sampling, pooling, or none).
# NOTE(review): not referenced in the visible part of this file — confirm usage.
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
# Result type for callables run on workers (`collective_rpc`, `apply_model`).
_R = TypeVar("_R", default=Any)

107
108

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
109
110
111
112
113
114
115
116
117
118
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
119
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
120
121
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
122
123
124
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
125
126
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
127
128
129
130
        allowed_local_media_path: Allows API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
131
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
133
134
135
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
136
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
137
138
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
139
        quantization: The method used to quantize the model weights. Currently,
140
            we support "awq", "gptq", and "fp8" (experimental).
141
142
143
144
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
145
146
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
147
148
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
149
150
151
152
153
154
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
155
156
157
158
159
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vLLM can automatically infer the KV cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the KV cache memory size. kv_cache_memory_bytes
            allows more fine-grained control of how much memory gets used
            than gpu_memory_utilization does. Note that when
            kv_cache_memory_bytes is not None, gpu_memory_utilization is
            ignored.
163
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, values that are too small may cause out-of-memory (OOM)
            errors. Note that `best_of` is only supported in V0.
169
170
171
172
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
173
174
175
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
176
        enable_return_routed_experts: Whether to return routed experts.
177
178
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
179
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `hf auth login` (stored in `~/.cache/huggingface/token`).
182
183
184
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
185
186
187
188
189
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
190
191
        pooler_config: Initialize non-default pooling config for the pooling model,
            e.g., `PoolerConfig(seq_pooling_type="MEAN", use_activation=False)`.
192
        compilation_config: An integer, a dictionary, or a `CompilationConfig`
            instance. If it is an integer, it is used as the mode of
            compilation optimization. If it is a dictionary, it can specify
            the full compilation configuration.
195
196
197
198
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
199
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
200

201
202
    Note:
        This class is intended to be used for offline inference. For online
203
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
204
    """
205
206
207
208

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes the config-like arguments (dicts, ints, or `None`) into
        config instances, builds `EngineArgs` from them plus any extra
        `**kwargs`, and creates the underlying `LLMEngine`.

        Raises:
            ValueError: If `kv_transfer_config` is given as a dict that fails
                validation, or if `data_parallel_size > 1` is requested
                without the `external_launcher` executor backend on a
                non-TPU platform.
        """
        # Stats logging is off by default for offline use unless the caller
        # explicitly passed `disable_log_stats`.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # If worker_cls is a class rather than a qualified string name,
            # serialize it with cloudpickle to avoid pickling issues.
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if isinstance(kwargs.get("kv_transfer_config"), dict):
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                # Surface the validation failure as a ValueError, keeping the
                # original pydantic error chained for debugging.
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance."""
            if value is None:
                return cls()
            if isinstance(value, dict):
                # Only keys accepted by `is_init_field(cls, k)` are forwarded.
                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
            return value

        if isinstance(compilation_config, int):
            # An integer selects the compilation mode directly.
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data parallel usage up front, since it may hang.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        # Shortcuts to engine-owned components.
        self.model_config = self.llm_engine.model_config
        self.renderer = self.llm_engine.renderer
        self.io_processor = self.llm_engine.io_processor
        self.input_processor = self.llm_engine.input_processor

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

373
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer exposed by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
375

376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
    def get_world_size(self, include_dp: bool = True) -> int:
        """Return the number of workers according to the parallel config.

        Args:
            include_dp: When truthy (the default), return
                `parallel_config.world_size_across_dp`, i.e. TP * PP * DP.
                Otherwise return `parallel_config.world_size` (TP * PP).

        Returns:
            The selected world size.
        """
        cfg = self.llm_engine.vllm_config.parallel_config
        return cfg.world_size_across_dp if include_dp else cfg.world_size

393
    def reset_mm_cache(self) -> None:
        """Clear the multi-modal caches: the renderer's first, then the engine's."""
        for clear_cache in (
            self.renderer.clear_mm_cache,
            self.llm_engine.reset_mm_cache,
        ):
            clear_cache()

397
    def get_default_sampling_params(self) -> SamplingParams:
        """Return a fresh `SamplingParams` built from the model's defaults.

        The override dict from `model_config.get_diff_sampling_param()` is
        fetched lazily and stored on `self.default_sampling_params`. A new
        `SamplingParams` object is constructed on every call. When the dict
        is empty (or `None`), a default-constructed `SamplingParams` is
        returned.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

404
405
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """Generate completions for the given prompts.

        Prompts are batched automatically with respect to the memory
        constraint. For the best performance, pass all of your prompts in a
        single call to this method instead of calling it repeatedly.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, the model's default sampling parameters are used.
                A single value is applied to every prompt; a list must have
                the same length as `prompts` and is paired with them one by
                one.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the prompt
                at the same index.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not `"generate"`.
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        params = (
            self.get_default_sampling_params()
            if sampling_params is None
            else sampling_params
        )
        return self._run_completion(
            prompts=prompts,
            params=params,
            output_type=RequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
465

466
467
468
469
    def enqueue(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Add generation requests to the engine queue without waiting.

        The requests are only added here. Call `wait_for_completion()`
        afterwards to drive the engine and collect the results.

        Args:
            prompts: The prompts to the LLM. See generate() for details.
            sampling_params: The sampling parameters for text generation.
                If None, the model's default sampling parameters are used.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
            use_tqdm: If True, shows a tqdm progress bar while adding requests.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            The request IDs of the enqueued requests.

        Raises:
            ValueError: If the model's runner type is not `"generate"`.
        """
        if self.model_config.runner_type != "generate":
            raise ValueError("LLM.enqueue() is only supported for generative models.")

        params = (
            sampling_params
            if sampling_params is not None
            else self.get_default_sampling_params()
        )
        return self._add_completion_requests(
            prompts=prompts,
            params=params,
            lora_request=lora_request,
            priority=priority,
            use_tqdm=use_tqdm,
            tokenization_kwargs=tokenization_kwargs,
        )

508
    @overload
    def wait_for_completion(
        self,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[RequestOutput | PoolingRequestOutput]: ...

    @overload
    def wait_for_completion(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]: ...

    def wait_for_completion(
        self,
        output_type: type[Any] | tuple[type[Any], ...] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[Any]:
        """Wait for all enqueued requests to complete and return results.

        This method processes all requests currently in the engine queue
        and returns their outputs. Use after enqueue() to get results.

        Args:
            output_type: The expected output type, or a tuple of accepted
                types. If None (the default), both `RequestOutput` and
                `PoolingRequestOutput` are accepted.
            use_tqdm: If `True`, shows a tqdm progress bar. It may also be a
                callable returning a `tqdm` instance, used to create the bar.
                If `False`, no progress bar is created.

        Returns:
            A list of output objects for all completed requests.
        """
        # Without an explicit type, accept both generation and pooling outputs.
        if output_type is None:
            output_type = (RequestOutput, PoolingRequestOutput)

        return self._run_engine(output_type, use_tqdm=use_tqdm)
545

Cyrus Leung's avatar
Cyrus Leung committed
546
    def _resolve_mm_lora(
        self,
        prompt: ProcessorInputs,
        lora_request: LoRARequest | None,
    ) -> LoRARequest | None:
        """Select the LoRA request to attach to a processed prompt.

        `lora_request` is returned unchanged in each of these cases:

        - the prompt is not multimodal;
        - the LoRA config is absent or has no `default_mm_loras`;
        - none of the prompt's placeholder modalities has a default LoRA;
        - more than one modality matches (a warning is logged, since only one
          LoRA per request is supported).

        When exactly one modality matches, a truthy `lora_request` still wins.
        A warning is logged if its `lora_int_id` differs from the modality
        LoRA's ID. Otherwise a new `LoRARequest` is built for that modality.
        Its ID is the modality's position in the alphabetically sorted
        `default_mm_loras` keys, plus one.
        """
        if prompt["type"] != "multimodal":
            return lora_request

        lora_config = self.llm_engine.vllm_config.lora_config
        if lora_config is None or not lora_config.default_mm_loras:
            return lora_request
        mm_loras = lora_config.default_mm_loras

        matched = set(prompt["mm_placeholders"].keys()).intersection(
            mm_loras.keys()
        )
        if not matched:
            return lora_request

        if len(matched) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be "
                "used by a single prompt consuming several modalities; "
                "currently we only support one lora per request; as such, "
                "lora(s) registered with modalities: %s will be skipped",
                matched,
            )
            return lora_request

        modality = next(iter(matched))
        lora_id = sorted(mm_loras).index(modality) + 1

        # An explicitly provided request always takes precedence; only warn
        # when it disagrees with the modality-default LoRA's ID.
        if lora_request:
            if lora_request.lora_int_id != lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt"
                )
            return lora_request

        return LoRARequest(modality, lora_id, mm_loras[modality])

598
599
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Run a method on every worker and collect the per-worker results.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                A callable receives the worker object as an extra leading
                `self` argument, followed by `args` and `kwargs`.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            Prefer this API for control messages only. Set up a separate
            data-plane communication channel to pass data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
629
630

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model held by each worker and return every result.

        !!! warning
            Results are transferred back from the workers, so avoid returning
            large arrays or tensors. If you must return them, move them to CPU
            first so they do not take up additional VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
642

643
644
    def beam_search(
        self,
        prompts: list[TokensPrompt | TextPrompt],
        params: BeamSearchParams,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        use_tqdm: bool = False,
        concurrency_limit: int | None = None,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is a `TextPrompt` or a
                `TokensPrompt`.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            use_tqdm: Whether to use tqdm to display the progress bar.
                Ignored (with a warning) when `concurrency_limit` is set.
            concurrency_limit: The maximum number of prompts whose beams are
                expanded together in one batch. If None, all prompts are
                processed in a single batch.

        Returns:
            One `BeamSearchOutput` per prompt, in input order. Each holds at
            most `params.beam_width` sequences, sorted in descending order of
            a key built from the EOS token id and `params.length_penalty`.

        Raises:
            ValueError: If `concurrency_limit` is given and is less than 1.
            NotImplementedError: If a prompt is an embeddings prompt or an
                encoder-decoder prompt.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        if concurrency_limit is not None and concurrency_limit < 1:
            # `concurrency_limit` is the step of the batching `range` below:
            # a zero step raises there and a negative step silently skips the
            # whole search, so reject both up front with a clear message.
            raise ValueError(
                "concurrency_limit must be a positive integer or None, "
                f"got {concurrency_limit}."
            )

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        engine_prompts = self._preprocess_cmpl(prompts)
        lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts))

        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar."
            )
            use_tqdm = False

        if concurrency_limit is None:
            # One batch containing every prompt. `max(..., 1)` keeps the range
            # step valid when `prompts` is empty (range() rejects a 0 step).
            concurrency_limit = max(len(engine_prompts), 1)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        sampling_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
        )
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, engine_prompts):
            if prompt["type"] == "embeds":
                raise NotImplementedError(
                    "Embedding prompt not supported for beam search"
                )
            if prompt["type"] == "enc_dec":
                raise NotImplementedError(
                    "Encoder-decoder prompt not supported for beam search"
                )

            instances.append(
                BeamSearchInstance(
                    prompt,
                    lora_request=lora_req,
                    logprobs=None,
                ),
            )

        for prompt_start in range(0, len(instances), concurrency_limit):
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress."
                )

            for _ in token_iter:
                # Flatten the live beams of every instance in this batch
                # (linear time, unlike repeated list concatenation).
                all_beams: list[BeamSearchSequence] = list(
                    itertools.chain.from_iterable(
                        instance.beams for instance in instances_batch
                    )
                )
                # pos[k]:pos[k + 1] is the slice of `all_beams` (and of the
                # step output) belonging to the k-th instance of the batch.
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch
                    )
                )
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:])
                )

                if len(all_beams) == 0:
                    break

                # only runs for one step
                # we don't need to use tqdm here
                output = self._render_and_run_requests(
                    prompts=(beam.get_prompt() for beam in all_beams),
                    params=self._params_to_seq(sampling_params, len(all_beams)),
                    output_type=RequestOutput,
                    lora_requests=[beam.lora_request for beam in all_beams],
                    use_tqdm=False,
                )

                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    current_beam.orig_prompt,
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs + [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                )

                                if token_id == eos_token_id and not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)

                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
                    instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive when the token budget ran out also compete
            # for the final ranking.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)

            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

804
    def _preprocess_cmpl(
        self,
        prompts: Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Turn prompts given to the non-chat LLM APIs (everything except
        [LLM.chat][]) into inputs that can be passed to `_add_request`.

        See [LLM.generate][] for the full description of the arguments.

        Returns:
            The `ProcessorInputs` objects produced by the renderer, ready to
            be passed into LLMEngine.
        """
        rndr = self.renderer
        cfg = self.model_config

        parsed: list = []
        for raw_prompt in prompts:
            parsed.append(parse_model_prompt(cfg, raw_prompt))

        # Per-call tokenization overrides are layered on the renderer's
        # completion defaults.
        overrides = tokenization_kwargs if tokenization_kwargs else {}
        tok_params = rndr.default_cmpl_tok_params.with_kwargs(**overrides)

        return rndr.render_cmpl(parsed, tok_params)
829

830
831
832
833
834
835
836
837
    def _preprocess_cmpl_one(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """
        Preprocess a single non-chat prompt via `_preprocess_cmpl`.

        Raises:
            ValueError: If preprocessing does not yield exactly one input.
        """
        [processed] = self._preprocess_cmpl([prompt], tokenization_kwargs)
        return processed

838
839
    def _preprocess_chat(
        self,
        conversations: Sequence[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Render chat conversations into prompts usable by the other LLM APIs.

        See [LLM.chat][] for the full description of the arguments.

        Returns:
            The `ProcessorInputs` objects produced by the renderer, ready to
            be passed into LLMEngine.
        """
        rndr = self.renderer

        # The explicit keyword arguments are combined with the caller's
        # `chat_template_kwargs` via `merge_kwargs`; `tokenize` is enabled
        # only when the renderer uses a Mistral tokenizer.
        explicit_template_kwargs = dict(
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenize=is_mistral_tokenizer(rndr.tokenizer),
        )
        params_for_chat = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=merge_kwargs(
                chat_template_kwargs, explicit_template_kwargs
            ),
        )

        overrides = tokenization_kwargs if tokenization_kwargs else {}
        tok_params = rndr.default_chat_tok_params.with_kwargs(**overrides)

        _rendered_conversations, rendered_prompts = rndr.render_chat(
            conversations,
            params_for_chat,
            tok_params,
            prompt_extras={"mm_processor_kwargs": mm_processor_kwargs},
        )
        return rendered_prompts
886

887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
    def _preprocess_chat_one(
        self,
        conversation: list[ChatCompletionMessageParam],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """
        Render a single conversation via `_preprocess_chat`.

        Raises:
            ValueError: If rendering does not yield exactly one input.
        """
        [rendered] = self._preprocess_chat(
            [conversation],
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        return rendered

913
914
    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Each conversation is rendered into a prompt with the tokenizer's chat
        template and then generated from, like [generate][vllm.LLM.generate].

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation or a sequence of conversations.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, the model's default sampling parameters are used.
                A single value is applied to every conversation; a sequence
                must have one entry per conversation and is paired with them
                in order.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, the chat template is asked to
                append the prompt that opens the model's reply.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Tool definitions made available to the chat template,
                if any.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            mm_processor_kwargs: Overrides for `processor.__call__`.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.

        Raises:
            ValueError: If the model is not run as a generative model.
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.chat() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        effective_params = (
            self.get_default_sampling_params()
            if sampling_params is None
            else sampling_params
        )

        return self._run_chat(
            messages=messages,
            params=effective_params,
            output_type=RequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

1007
1008
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        Prompts are batched automatically with the memory constraint in
        mind. For the best performance, put all of your prompts into a single
        list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict with
                a `"data"` key is routed through the installed IOProcessor
                plugin instead.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Deprecated; pass `truncate_prompt_tokens`
                via `tokenization_kwargs` instead.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run. Required.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            In the IOProcessor path, a single output wrapping the plugin's
            post-processed result.

        Raises:
            ValueError: If `pooling_task` is missing, the model is not a
                pooling model, no IOProcessor is installed for a data prompt,
                the data prompt has no data, or a pooling param already
                carries a different task.
        """
        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        if self.model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if truncate_prompt_tokens is not None:
            warnings.warn(
                "The `truncate_prompt_tokens` parameter in `LLM.encode()` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        use_io_processor = isinstance(prompts, dict) and "data" in prompts

        if use_io_processor:
            plugin = self.io_processor
            if plugin is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            raw_data = prompts.get("data")
            if raw_data is None:
                raise ValueError(
                    "The 'data' field of the prompt is expected to contain "
                    "the prompt data and it cannot be None. "
                    "Refer to the documentation of the IOProcessor "
                    "in use for more details."
                )

            # The plugin first validates the raw data, then turns it into the
            # actual model prompts.
            prompts = plugin.pre_process(prompt=plugin.parse_data(raw_data))
            prompts_seq = prompt_to_seq(prompts)

            params_seq: Sequence[PoolingParams] = list(
                map(
                    plugin.merge_pooling_params,
                    self._params_to_seq(pooling_params, len(prompts_seq)),
                )
            )
            for plugin_param in params_seq:
                if plugin_param.task is None:
                    plugin_param.task = "plugin"
        else:
            prompts_seq = prompt_to_seq(prompts)
            params_seq = self._params_to_seq(
                PoolingParams() if pooling_params is None else pooling_params,
                len(prompts_seq),
            )

            # NOTE(review): this assigns `task` in place; if `_params_to_seq`
            # hands back the caller's own PoolingParams objects, they keep the
            # task after this call returns — confirm that is intended.
            for param in params_seq:
                if param.task is None:
                    param.task = pooling_task
                elif param.task != pooling_task:
                    raise ValueError(
                        f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                    )

        outputs = self._run_completion(
            prompts=prompts_seq,
            params=params_seq,
            output_type=PoolingRequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        if not use_io_processor:
            return outputs

        # The plugin post-processes all outputs into one result, which is
        # wrapped in a single synthetic request output.
        assert self.io_processor is not None
        post_processed = self.io_processor.post_process(outputs)
        return [
            PoolingRequestOutput[Any](
                request_id="",
                outputs=post_processed,
                num_cached_tokens=getattr(post_processed, "num_cached_tokens", 0),
                prompt_token_ids=[],
                finished=True,
            )
        ]
1162

1163
1164
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        Prompts are batched automatically with the memory constraint in
        mind. For the best performance, put all of your prompts into a single
        list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If set, merged into `tokenization_kwargs`
                as `truncate_prompt_tokens`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        # Route the shortcut argument through tokenization_kwargs so that
        # `encode` does not see its deprecated parameter.
        if truncate_prompt_tokens is not None:
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        pooled = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
            tokenization_kwargs=tokenization_kwargs,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results (class logits) in the same order as the
            input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1270
1271
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: If set, merged into `tokenization_kwargs`
                as `truncate_prompt_tokens`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        if truncate_prompt_tokens is not None:
            # Fold the shortcut into tokenization_kwargs (as `embed` does)
            # rather than passing it to `encode`. `encode` treats its own
            # `truncate_prompt_tokens` parameter as deprecated and would emit
            # a DeprecationWarning about `LLM.encode()`, attributed to this
            # frame, even though the caller used `LLM.reward()`.
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_classify",
            tokenization_kwargs=tokenization_kwargs,
        )

1312
1313
    def _embedding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Score pairs by cosine similarity of their pooled embeddings.

        Both sides are embedded in a single `encode` call. When `data_1` has
        exactly one item, it is paired with every item of `data_2`.

        Raises:
            NotImplementedError: If any input is not a plain string
                (multimodal inputs are not supported here).
        """
        tokenizer = self.get_tokenizer()

        texts: list[str] = []
        for item in (*data_1, *data_2):
            if not isinstance(item, str):
                raise NotImplementedError(
                    "Embedding scores currently do not support multimodal input."
                )
            texts.append(item)

        pooled = self.encode(
            texts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        split_at = len(data_1)
        left_embeds = pooled[:split_at]
        right_embeds = pooled[split_at:]

        # Broadcast a single left-hand item against every right-hand item.
        if len(left_embeds) == 1:
            left_embeds = left_embeds * len(right_embeds)

        similarities = _cosine_similarity(
            tokenizer=tokenizer,
            embed_1=left_embeds,
            embed_2=right_embeds,
        )
        return list(map(ScoringRequestOutput.from_base, similarities))
1354

1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
    def _late_interaction_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Late interaction scoring (ColBERT MaxSim).

        Queries and documents are encoded together into per-token embeddings
        (`pooling_task="token_embed"`), then each pair is scored with
        `compute_maxsim_score`: the sum over query tokens of the maximum
        similarity to any document token. When there is a single query, it
        is paired with every document.
        """
        from vllm.outputs import PoolingOutput

        tokenizer = self.get_tokenizer()
        model_config = self.model_config

        # Turn ScoreData (text or multimodal) into engine prompts.
        query_prompts = score_data_to_prompts(data_1, "query", model_config)
        doc_prompts = score_data_to_prompts(data_2, "document", model_config)
        num_queries = len(query_prompts)

        token_embeds: list[PoolingRequestOutput] = self.encode(
            query_prompts + doc_prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_embed",
            tokenization_kwargs=tokenization_kwargs,
        )
        query_outputs: list[PoolingRequestOutput] = token_embeds[:num_queries]
        doc_outputs: list[PoolingRequestOutput] = token_embeds[num_queries:]

        if len(query_outputs) == 1:
            query_outputs = query_outputs * len(doc_outputs)

        # The reported prompt joins both token sequences, separated by the
        # pad token when the tokenizer defines one.
        pad_token_id = tokenizer.pad_token_id
        separator: list[int] = [] if pad_token_id is None else [pad_token_id]

        results: list[PoolingRequestOutput] = []
        for query_out, doc_out in zip(query_outputs, doc_outputs):
            # query_out.outputs.data: [query_len, dim]
            # doc_out.outputs.data:   [doc_len, dim]
            query_data = query_out.outputs.data
            doc_data = doc_out.outputs.data
            maxsim = compute_maxsim_score(query_data, doc_data)

            joined_ids = (
                query_out.prompt_token_ids + separator + doc_out.prompt_token_ids
            )
            cached = query_out.num_cached_tokens + doc_out.num_cached_tokens
            results.append(
                PoolingRequestOutput(
                    request_id=f"{query_out.request_id}_{doc_out.request_id}",
                    outputs=PoolingOutput(data=maxsim),
                    prompt_token_ids=joined_ids,
                    num_cached_tokens=cached,
                    finished=True,
                )
            )

        return [ScoringRequestOutput.from_base(result) for result in results]

    def _cross_encoding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
        score_template: str | None,
    ) -> list[ScoringRequestOutput]:
        """Score each `(data_1[i], data_2[i])` pair with a cross-encoder.

        Each pair is rendered into a single prompt by `get_score_prompt` and
        run through the model with a `"score"` pooling task. When `data_1`
        holds a single item, it is paired with every `data_2` item.

        Args:
            data_1: Left-hand items (queries).
            data_2: Right-hand items (documents).
            use_tqdm: Progress-bar control forwarded to `_run_completion`.
            pooling_params: Optional pooling parameters. When given with
                `task=None`, a copy with `task="score"` is used; the caller's
                object is not modified.
            lora_request: LoRA request(s) forwarded to `_run_completion`.
            tokenization_kwargs: Forwarded to `get_score_prompt`.
            score_template: Optional template forwarded to `get_score_prompt`.

        Returns:
            One `ScoringRequestOutput` per pair.

        Raises:
            ValueError: If the tokenizer is a Mistral tokenizer.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        if is_mistral_tokenizer(tokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        elif pooling_params.task is None:
            # Work on a copy so the caller's PoolingParams is not mutated.
            pooling_params = pooling_params.clone()
            pooling_params.task = "score"

        pooling_params_list: list[PoolingParams] = []
        prompts: list[PromptType] = []

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            # token_type_ids are removed from the prompt and carried instead
            # in compressed form on a per-request copy of the pooling params.
            if token_type_ids := engine_prompt.pop("token_type_ids", None):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        outputs = self._run_completion(
            prompts=prompts,
            params=pooling_params_list,
            output_type=PoolingRequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        return [ScoringRequestOutput.from_base(item) for item in outputs]

    def score(
        self,
        data_1: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        data_2: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
        `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.

        How each pair is scored depends on the model:

        - cross-encoder models score each pair jointly as a single prompt;
        - late-interaction models (e.g. ColBERT) use MaxSim over per-token
          embeddings;
        - all other models use the cosine similarity of pooled embeddings,
          which currently supports text input only.

        This class automatically batches the prompts, considering the memory
        constraint. For the best performance, put all of your inputs into a
        single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: The chat template to use for the scoring. If None, we
                use the model's default chat template. Only supported for
                cross-encoder models.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, does not support
                scoring, is a cross-encoder with `num_labels != 1`, or if
                `chat_template` is given for a non-cross-encoder model.
        """
        model_config = self.model_config

        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks

        # Late interaction models (e.g., ColBERT) use token_embed for scoring
        is_late_interaction = model_config.is_late_interaction
        if not is_late_interaction and all(
            t not in supported_tasks for t in ("embed", "classify")
        ):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        if (
            model_config.is_cross_encoder
            and getattr(model_config.hf_config, "num_labels", 0) != 1
        ):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        if not model_config.is_cross_encoder and chat_template is not None:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        is_multimodal_model = model_config.is_multimodal_model
        architecture = model_config.architecture

        score_data_1, score_data_2 = validate_score_input(
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            is_multimodal_model=is_multimodal_model,
            architecture=architecture,
        )

        # Merge the caller's overrides into the renderer's default
        # completion tokenization params.
        renderer = self.renderer
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )
        encode_kwargs = tok_params.get_encode_kwargs()

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
                score_template=chat_template,
            )
        elif is_late_interaction:
            return self._late_interaction_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
            )
        else:
            return self._embedding_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
            )

    def start_profile(self, profile_prefix: str | None = None) -> None:
        """Start profiling with optional custom trace prefix.

        Args:
            profile_prefix: Optional prefix for the trace file names. If provided,
                           trace files will be named as "<prefix>_dp<X>_pp<Y>_tp<Z>".
                           If not provided, default naming will be used.
        """
        self.llm_engine.start_profile(profile_prefix)
1627
1628
1629
1630

    def stop_profile(self) -> None:
        """Stop the profiling session started by `start_profile`."""
        engine = self.llm_engine
        engine.stop_profile()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache and return the engine's result.

        Both flags are forwarded positionally to the engine's
        `reset_prefix_cache`.
        """
        engine = self.llm_engine
        result = engine.reset_prefix_cache(reset_running_requests, reset_connector)
        return result

    def sleep(self, level: int = 1, mode: PauseMode = "abort"):
        """
        Put the engine to sleep. No requests should be processed while it
        sleeps: the caller is responsible for ensuring nothing is in flight
        until `wake_up` is called.

        Args:
            level: How deep to sleep.
                - Level 0: Pause scheduling but keep accepting requests.
                           Requests are queued but not processed.
                - Level 1: Offload model weights to CPU and discard the KV
                           cache (its contents are lost). Suited to waking
                           up and running the same model again; make sure
                           there is enough CPU memory to hold the weights.
                - Level 2: Discard all GPU memory (weights + KV cache).
                           Suited to switching to a different model or
                           updating the model, when the previous weights
                           are not needed. It reduces CPU memory pressure.
            mode: How to handle any existing requests: "abort", "wait",
                or "keep".
        """
        engine = self.llm_engine
        engine.sleep(mode=mode, level=level)

    def wake_up(self, tags: list[str] | None = None):
1664
        """
1665
1666
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1667

1668
        Args:
1669
1670
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1671
1672
1673
1674
                `("weights", "kv_cache", "scheduling")`. If None, all memory
                is reallocated. wake_up should be called with all tags
                (or None) before the engine is used again.
                Use tags=["scheduling"] to resume from level 0 sleep.
1675
1676
        """
        self.llm_engine.wake_up(tags)
1677

1678
1679
1680
1681
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus, as returned by the engine's
            `get_metrics`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

    def _params_to_seq(
        self,
        params: _P | Sequence[_P],
        num_requests: int,
    ) -> Sequence[_P]:
        """Expand `params` into exactly one entry per request.

        Args:
            params: A single params object, or a sequence with one entry per
                request.
            num_requests: Number of requests (prompts) being submitted.

        Returns:
            `params` itself when it is a sequence of the right length;
            otherwise a list that repeats the single object `num_requests`
            times (every entry is the same object).

        Raises:
            ValueError: If `params` is a sequence whose length differs from
                `num_requests`.
        """
        if isinstance(params, Sequence):
            if len(params) != num_requests:
                # Report the prompt count, not the params sequence itself.
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and params ({len(params)}) must be the same."
                )

            return params

        return [params] * num_requests

    def _lora_request_to_seq(
        self,
        lora_request: LoRARequest | None | Sequence[LoRARequest | None],
        num_requests: int,
    ) -> Sequence[LoRARequest | None]:
        """Expand `lora_request` into exactly one entry per request.

        A single request (or `None`) is repeated `num_requests` times; a
        sequence is length-checked and returned unchanged.

        Raises:
            ValueError: If a sequence of the wrong length is given.
        """
        if not isinstance(lora_request, Sequence):
            return [lora_request] * num_requests

        num_loras = len(lora_request)
        if num_loras != num_requests:
            raise ValueError(
                f"The lengths of prompts ({num_requests}) "
                f"and lora_request ({num_loras}) must be the same."
            )
        return lora_request

    def _priority_to_seq(
        self,
        priority: list[int] | None,
        num_requests: int,
    ) -> Sequence[int]:
1727
1728
1729
1730
1731
1732
1733
        if priority is not None:
            if len(priority) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and priority ({len(priority)}) must be the same."
                )

1734
1735
1736
1737
            return priority

        return [0] * num_requests

1738
    def _add_completion_requests(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Render completion prompts and add them to the engine.

        Args:
            prompts: A single prompt or a sequence of prompts.
            params: Sampling/pooling params, either one shared object or one
                per prompt.
            use_tqdm: Progress-bar control for the rendering loop.
            lora_request: One LoRA request shared by all prompts, or one per
                prompt.
            priority: Optional per-prompt priorities (default `0`).
            tokenization_kwargs: Base tokenization kwargs, merged with each
                param's `truncate_prompt_tokens`.

        Returns:
            The request IDs that were added, in submission order.
        """
        seq_prompts = prompt_to_seq(prompts)
        # Every per-request argument is normalized against the normalized
        # prompt count. `prompts` itself may be a single prompt (e.g. a
        # string or a prompt dict), whose `len()` is not the request count.
        num_requests = len(seq_prompts)
        seq_params = self._params_to_seq(params, num_requests)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
        seq_tok_kwargs = [
            merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
            )
            for param in seq_params
        ]
        seq_priority = self._priority_to_seq(priority, num_requests)

        return self._render_and_add_requests(
            # Lazily render so each prompt is added as soon as it is ready.
            prompts=(
                self._preprocess_cmpl_one(prompt, tok_kwargs)
                for prompt, tok_kwargs in zip(
                    maybe_tqdm(
                        seq_prompts, use_tqdm=use_tqdm, desc="Rendering prompts"
                    ),
                    seq_tok_kwargs,
                )
            ),
            params=seq_params,
            lora_requests=seq_lora_requests,
            priorities=seq_priority,
        )

    def _run_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ):
        """Queue completion requests, then drive the engine to completion.

        All arguments except `output_type` are forwarded to
        `_add_completion_requests`; `output_type` and `use_tqdm` go to
        `_run_engine`, whose sorted list of finished outputs is returned.
        """
        self._add_completion_requests(
            prompts,
            params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
        )
        finished = self._run_engine(output_type=output_type, use_tqdm=use_tqdm)
        return finished

    def _run_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ):
        """Render one or more conversations and run them through the engine.

        Each conversation is rendered lazily with `_preprocess_chat_one`, so
        engine work can start before every conversation has been rendered.
        Per-request params and LoRA requests are expanded to match the number
        of conversations; the finished outputs of `output_type` are returned.
        """
        seq_convs = conversation_to_seq(messages)
        num_convs = len(seq_convs)
        seq_params = self._params_to_seq(params, num_convs)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_convs)

        # Per-request tokenization kwargs: the shared base plus each param's
        # own truncation setting.
        seq_tok_kwargs = []
        for param in seq_params:
            truncation = dict(truncate_prompt_tokens=param.truncate_prompt_tokens)
            seq_tok_kwargs.append(merge_kwargs(tokenization_kwargs, truncation))

        # Built up front (like the outer iterable of a generator expression)
        # so the progress bar is created here, before rendering starts.
        conv_iter = zip(
            maybe_tqdm(seq_convs, use_tqdm=use_tqdm, desc="Rendering conversations"),
            seq_tok_kwargs,
        )
        rendered = (
            self._preprocess_chat_one(
                conversation,
                chat_template=chat_template,
                chat_template_content_format=chat_template_content_format,
                chat_template_kwargs=chat_template_kwargs,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                tokenization_kwargs=conv_tok_kwargs,
                mm_processor_kwargs=mm_processor_kwargs,
            )
            for conversation, conv_tok_kwargs in conv_iter
        )

        return self._render_and_run_requests(
            prompts=rendered,
            params=seq_params,
            output_type=output_type,
            lora_requests=seq_lora_requests,
            use_tqdm=use_tqdm,
        )

    def _render_and_run_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ):
        """Add the rendered prompts to the engine and collect the outputs.

        A list or tuple of prompts triggers a one-time warning: rendering
        everything up front prevents engine work from overlapping with
        rendering, which a generator would allow.
        """
        already_materialized = isinstance(prompts, (list, tuple))
        if already_materialized:
            logger.warning_once(
                "Rendering all prompts before adding them to the engine "
                "is less efficient than performing both on the same prompt "
                "before processing the next prompt. You should instead pass "
                "a generator that renders one prompt per iteration, as that allows "
                "engine execution to begin for the first prompt while processing "
                "the next prompt."
            )

        self._render_and_add_requests(
            prompts=prompts,
            params=params,
            lora_requests=lora_requests,
            priorities=priorities,
        )
        finished = self._run_engine(output_type, use_tqdm=use_tqdm)
        return finished

    def _render_and_add_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
    ) -> list[str]:
        """Add each rendered prompt to the engine as a separate request.

        Args:
            prompts: Rendered prompts. May be a lazy generator; rendering
                errors surface while iterating.
            params: Per-prompt params, indexed in step with `prompts`.
            lora_requests: Optional per-prompt LoRA requests, resolved through
                `_resolve_mm_lora`.
            priorities: Optional per-prompt priorities (default `0`).

        Returns:
            The IDs of the added requests, in order.

        Raises:
            Any exception raised while rendering or adding a prompt. Before it
            propagates, every request already added is aborted.
        """
        added_request_ids: list[str] = []

        try:
            for i, prompt in enumerate(prompts):
                request_id = self._add_request(
                    prompt,
                    params[i],
                    lora_request=self._resolve_mm_lora(
                        prompt,
                        None if lora_requests is None else lora_requests[i],
                    ),
                    priority=0 if priorities is None else priorities[i],
                )
                added_request_ids.append(request_id)
        except Exception:
            # Roll back the requests queued so far so a failure part-way
            # through does not leave them running in the engine.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            raise

        return added_request_ids

    def _add_request(
        self,
        prompt: ProcessorInputs,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
    ) -> str:
        """Submit a single rendered prompt to the engine.

        For sampling requests, `params.output_kind` is set to `FINAL_ONLY`
        in place. The request ID is the next value of `self.request_counter`,
        as a string. Returns whatever the engine's `add_request` returns.
        """
        # Offline usage only needs the final output of each request.
        if isinstance(params, SamplingParams):
            params.output_kind = RequestOutputKind.FINAL_ONLY

        request_id = str(next(self.request_counter))
        engine = self.llm_engine
        return engine.add_request(
            request_id,
            prompt,
            params,
            priority=priority,
            lora_request=lora_request,
        )

    def _run_engine(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]:
        """Step the engine until no unfinished requests remain.

        Args:
            output_type: Expected type (or tuple of types) of every step
                output; each output is checked with an assertion.
            use_tqdm: If truthy, show a progress bar. A callable is used as
                the progress-bar factory instead of `tqdm`. For
                `RequestOutput`s the bar also shows estimated input/output
                token throughput.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[_O] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                assert isinstance(output, output_type)
                if not output.finished:
                    continue

                outputs.append(output)  # type: ignore[arg-type]
                if not use_tqdm:
                    continue

                if isinstance(output, RequestOutput):
                    # Calculate tokens only for RequestOutput
                    n = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids) * n
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs
                    )
                    elapsed = pbar.format_dict["elapsed"]
                    # Elapsed time can still be 0 with a coarse clock; skip
                    # the speed estimate rather than divide by zero.
                    if elapsed > 0:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s"
                        )
                    pbar.update(n)
                else:
                    pbar.update(1)
                if pbar.n == num_requests:
                    pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))

    def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest | dict
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Args:
            request: A `WeightTransferInitRequest`, or a plain dict with an
                `"init_info"` key, carrying backend-specific init info. That
                info is broadcast to the workers via `collective_rpc`.
        """
        if isinstance(request, dict):
            init_info_dict = request["init_info"]
        else:
            init_info_dict = request.init_info

        rpc_kwargs = {"init_info": init_info_dict}
        self.llm_engine.collective_rpc("init_weight_transfer_engine", kwargs=rpc_kwargs)

    def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
        """
        Update the weights of the model.

        Args:
            request: A `WeightTransferUpdateRequest`, or a plain dict with an
                `"update_info"` key, carrying backend-specific update info.
                That info is broadcast to the workers via `collective_rpc`.
        """
        if isinstance(request, dict):
            update_info_dict = request["update_info"]
        else:
            update_info_dict = request.update_info

        rpc_kwargs = {"update_info": update_info_dict}
        self.llm_engine.collective_rpc("update_weights", kwargs=rpc_kwargs)

    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model."""
        # Cache the result to avoid repeated collective_rpc calls
        if self._cached_repr is None:
            results = self.llm_engine.collective_rpc("get_model_inspection")
            # In distributed settings, we get results from all workers
            # Just return the first one (they should all be the same)
            if results:
                self._cached_repr = results[0]
            else:
                self._cached_repr = f"LLM(model={self.model_config.model!r})"
        return self._cached_repr