llm.py 82 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
7
from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Any
8

9
import cloudpickle
10
import torch.nn as nn
11
from pydantic import ValidationError
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar, overload
14

15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
22
    AttentionConfig,
23
    CompilationConfig,
24
    PoolerConfig,
25
    ProfilerConfig,
26
27
28
    StructuredOutputsConfig,
    is_init_field,
)
29
from vllm.config.compilation import CompilationMode
30
from vllm.config.model import (
31
32
    ConvertOption,
    HfOverrides,
33
    ModelDType,
34
    RunnerOption,
35
    TokenizerMode,
36
)
37
38
39
40
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
41
from vllm.engine.arg_utils import EngineArgs
42
43
44
45
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
46
from vllm.entrypoints.pooling.score.utils import (
47
    ScoreData,
48
49
50
    ScoreMultiModalParam,
    _cosine_similarity,
    compress_token_type_ids,
51
    compute_maxsim_score,
52
    get_score_prompt,
53
    score_data_to_prompts,
54
    validate_score_input,
55
)
56
from vllm.entrypoints.utils import log_non_default_args
57
from vllm.inputs.data import (
58
    DataPrompt,
59
    ProcessorInputs,
60
61
62
63
64
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
65
from vllm.logger import init_logger
66
from vllm.lora.request import LoRARequest
67
from vllm.model_executor.layers.quantization import QuantizationMethods
68
69
70
71
72
73
74
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
75
from vllm.platforms import current_platform
76
from vllm.pooling_params import PoolingParams
77
from vllm.renderers import ChatParams, merge_kwargs
78
79
80
81
82
from vllm.renderers.inputs.preprocess import (
    conversation_to_seq,
    parse_model_prompt,
    prompt_to_seq,
)
83
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
84
from vllm.tasks import PoolingTask
85
86
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
yhu422's avatar
yhu422 committed
87
from vllm.usage.usage_lib import UsageContext
88
from vllm.utils.counter import Counter
89
from vllm.utils.tqdm_utils import maybe_tqdm
90
from vllm.v1.engine.llm_engine import LLMEngine
91
from vllm.v1.sample.logits_processor import LogitsProcessor
92

93
94
95
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

96
97
logger = init_logger(__name__)

# Output type for the typed `wait_for_completion` overload; defaults to the
# union of generation and pooling outputs.
_O = TypeVar(
    "_O",
    bound=RequestOutput | PoolingRequestOutput,
    default=RequestOutput | PoolingRequestOutput,
)
# Per-request parameter type: sampling params, pooling params, or None.
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
# Generic result type, e.g. of worker-side callables passed to
# `collective_rpc` / `apply_model`, and of the config helper in `__init__`.
_R = TypeVar("_R", default=Any)

106
107

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Inputs are then expected to provide valid
            `prompt_token_ids` and `None` for `prompt`.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allows API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vLLM can automatically infer the KV cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the KV cache memory size. kv_cache_memory_bytes
            allows more fine-grained control of how much memory gets used than
            gpu_memory_utilization does. Note that when kv_cache_memory_bytes
            is set (not None), gpu_memory_utilization is ignored.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        enable_return_routed_experts: Whether to return routed experts.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `hf auth login` (stored in `~/.cache/huggingface/token`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        pooler_config: Initialize non-default pooling config for the pooling model,
            e.g., `PoolerConfig(seq_pooling_type="MEAN", use_activation=False)`.
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig` instance. If it is an integer, it is used as
            the mode of compilation optimization. If it is a dictionary, it can
            specify the full compilation configuration.
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
        structured_outputs_config: Structured outputs configuration, given as
            a dictionary or a `StructuredOutputsConfig` instance.
        profiler_config: Profiler configuration, given as a dictionary or a
            `ProfilerConfig` instance.
        logits_processors: Custom logits processors, given as strings or
            `LogitsProcessor` subclasses.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """
204
205
206
207

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes the user-facing arguments (documented on the class) into an
        [`EngineArgs`][vllm.EngineArgs] instance and builds the engine with
        `LLMEngine.from_engine_args`. Dict-valued configs are converted to
        their config classes first, and an integer `compilation_config` is
        interpreted as a `CompilationMode`.

        Raises:
            ValueError: If `kwargs["kv_transfer_config"]` is a dict that fails
                validation as a `KVTransferConfig`, or if
                `data_parallel_size > 1` is requested without the
                `external_launcher` executor backend on a non-TPU platform.
        """
        # Stats logging is off by default for offline use; callers can still
        # opt in by passing `disable_log_stats=False`.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # A class object (rather than a qualified-name string) is
            # serialized with cloudpickle to avoid pickling issues.
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        raw_config_dict = kwargs.get("kv_transfer_config")
        if isinstance(raw_config_dict, dict):
            from vllm.config.kv_transfer import KVTransferConfig

            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance.

            None yields `cls()`. For a dict, keys that are not init fields of
            `cls` are dropped before construction. Any other value is
            returned unchanged.
            """
            if value is None:
                return cls()
            if isinstance(value, dict):
                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
            return value

        if isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data parallel usage: it is not supported
        # and may hang.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        # Convenience handles onto engine components.
        self.model_config = self.llm_engine.model_config
        self.renderer = self.llm_engine.renderer
        self.io_processor = self.llm_engine.io_processor
        self.input_processor = self.llm_engine.input_processor

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

372
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer exposed by the underlying engine."""
        return self.llm_engine.get_tokenizer()
374

375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
    def get_world_size(self, include_dp: bool = True) -> int:
        """Get the world size from the parallel config.

        Args:
            include_dp: If True (default), returns the world size including
                data parallelism (TP * PP * DP). If False, returns the world
                size without data parallelism (TP * PP).

        Returns:
            The world size (tensor_parallel_size * pipeline_parallel_size),
            optionally multiplied by data_parallel_size if include_dp is True.
        """
        parallel_config = self.llm_engine.vllm_config.parallel_config
        if include_dp:
            return parallel_config.world_size_across_dp
        return parallel_config.world_size

392
    def reset_mm_cache(self) -> None:
393
        self.renderer.clear_mm_cache()
394
395
        self.llm_engine.reset_mm_cache()

396
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default sampling parameters for this model.

        The model config's `get_diff_sampling_param()` result is fetched on
        first use and cached in `self.default_sampling_params`. If that dict
        is non-empty, it is passed to `SamplingParams.from_optional`;
        otherwise a plain `SamplingParams()` is returned. A new
        `SamplingParams` object is built on every call.
        """
        if self.default_sampling_params is None:
            # Fetch the model-specific overrides once and cache them.
            self.default_sampling_params = self.model_config.get_diff_sampling_param()

        if self.default_sampling_params:
            return SamplingParams.from_optional(**self.default_sampling_params)
        return SamplingParams()

403
404
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the prompt
                at the same index.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a generative model.
        """
        runner_type = self.model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        return self._run_completion(
            prompts=prompts,
            params=sampling_params,
            output_type=RequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
465

466
467
468
469
    def enqueue(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Enqueue prompts for generation without waiting for completion.

        This method adds requests to the engine queue but does not start
        processing them. Use `wait_for_completion()` to process the queued
        requests and get results.

        Args:
            prompts: The prompts to the LLM. See `generate()` for details.
            sampling_params: The sampling parameters for text generation. If
                None, the default sampling parameters are used. See
                `generate()` for how single values and sequences are paired
                with prompts.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any. See `generate()`.
            use_tqdm: If True, shows a tqdm progress bar while rendering the
                prompts.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of request IDs for the enqueued requests.

        Raises:
            ValueError: If the model is not run as a generative model.
        """
        runner_type = self.model_config.runner_type
        if runner_type != "generate":
            raise ValueError("LLM.enqueue() is only supported for generative models.")

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        # Use the same preprocessing as _run_completion
        seq_prompts = prompt_to_seq(prompts)
        num_prompts = len(seq_prompts)
        seq_params = self._params_to_seq(sampling_params, num_prompts)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_prompts)
        # Per-prompt tokenization kwargs, carrying each request's
        # `truncate_prompt_tokens` setting.
        seq_tok_kwargs = [
            merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
            )
            for param in seq_params
        ]
        # Size priorities by the normalized prompt list, like params and LoRA
        # requests above. When `prompts` is a single prompt (e.g. a string or
        # a dict), `len(prompts)` is not the number of prompts.
        seq_priority = self._priority_to_seq(priority, num_prompts)

        return self._render_and_add_requests(
            prompts=(
                self._preprocess_cmpl_one(prompt, tok_kwargs)
                for prompt, tok_kwargs in zip(
                    maybe_tqdm(
                        seq_prompts,
                        use_tqdm=use_tqdm,
                        desc="Rendering prompts",
                    ),
                    seq_tok_kwargs,
                )
            ),
            params=seq_params,
            lora_requests=seq_lora_requests,
            priorities=seq_priority,
        )

532
    @overload
    def wait_for_completion(
        self,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[RequestOutput | PoolingRequestOutput]: ...

    @overload
    def wait_for_completion(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]: ...

    def wait_for_completion(
        self,
        output_type: type[Any] | tuple[type[Any], ...] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[Any]:
        """Wait for all enqueued requests to complete and return results.

        This method processes all requests currently in the engine queue
        and returns their outputs. Use after `enqueue()` to get results.

        Args:
            output_type: The expected output type, or a tuple of types. If
                None (the default), both `RequestOutput` and
                `PoolingRequestOutput` are accepted.
            use_tqdm: If True, shows a tqdm progress bar.

        Returns:
            A list of output objects for all completed requests.
        """
        if output_type is None:
            output_type = (RequestOutput, PoolingRequestOutput)

        return self._run_engine(output_type, use_tqdm=use_tqdm)
569

Cyrus Leung's avatar
Cyrus Leung committed
570
    def _resolve_mm_lora(
        self,
        prompt: ProcessorInputs,
        lora_request: LoRARequest | None,
    ) -> LoRARequest | None:
        """Choose the LoRA request for a single processed prompt.

        For a multimodal prompt, the default modality-specific LoRA from
        `lora_config.default_mm_loras` is used when exactly one of the
        prompt's modalities has one registered and no explicit
        `lora_request` was given. In all other cases the explicitly provided
        `lora_request` (possibly None) is returned unchanged.

        Args:
            prompt: The processed inputs of one request.
            lora_request: The LoRA request supplied by the caller, if any.

        Returns:
            The LoRA request to attach to the request, or None.
        """
        if prompt["type"] != "multimodal":
            return lora_request

        lora_config = self.llm_engine.vllm_config.lora_config
        default_mm_loras = None if lora_config is None else lora_config.default_mm_loras
        if not default_mm_loras:
            return lora_request

        prompt_modalities = prompt["mm_placeholders"].keys()
        intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
        if not intersection:
            return lora_request

        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be "
                "used by a single prompt consuming several modalities; "
                "currently we only support one lora per request; as such, "
                "lora(s) registered with modalities: %s will be skipped",
                intersection,
            )
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always takes precedence; warn only
        # when its ID collides with (differs from) the default modality LoRA.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt"
                )
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

622
623
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
653
654

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
        """
        return self.llm_engine.apply_model(func)
666

667
668
    def beam_search(
        self,
        prompts: list[TokensPrompt | TextPrompt],
        params: BeamSearchParams,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        use_tqdm: bool = False,
        concurrency_limit: int | None = None,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is a `TextPrompt` or a
                `TokensPrompt`; embedding and encoder-decoder prompts are
                rejected.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            use_tqdm: Whether to use tqdm to display the progress bar.
                Ignored (with a warning) when `concurrency_limit` is set.
            concurrency_limit: The maximum number of prompts processed
                concurrently. If None, all prompts are processed in a single
                batch. Must be a positive integer when given.

        Returns:
            One `BeamSearchOutput` per prompt, in input order, each holding
            at most `beam_width` sequences sorted best-first. An empty
            `prompts` list yields an empty list.

        Raises:
            ValueError: If `concurrency_limit` is given and is not positive.
            NotImplementedError: If a prompt is an embedding or
                encoder-decoder prompt.
        """
        # A zero step would crash `range` below, and a negative step would
        # silently skip generation altogether, so reject both up front.
        if concurrency_limit is not None and concurrency_limit < 1:
            raise ValueError(
                "concurrency_limit must be a positive integer, "
                f"got {concurrency_limit}."
            )

        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        engine_prompts = self._preprocess_cmpl(prompts)
        lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts))

        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar."
            )
            use_tqdm = False

        if concurrency_limit is None:
            # Process everything in one batch. `max(..., 1)` keeps the batch
            # step non-zero when `prompts` is empty.
            concurrency_limit = max(len(engine_prompts), 1)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        sampling_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
        )
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, engine_prompts):
            if prompt["type"] == "embeds":
                raise NotImplementedError(
                    "Embedding prompt not supported for beam search"
                )
            if prompt["type"] == "enc_dec":
                raise NotImplementedError(
                    "Encoder-decoder prompt not supported for beam search"
                )

            instances.append(
                BeamSearchInstance(
                    prompt,
                    lora_request=lora_req,
                    logprobs=None,
                ),
            )

        for prompt_start in range(0, len(instances), concurrency_limit):
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress."
                )
            for _ in token_iter:
                # Flatten the live beams of every instance in the batch, in
                # instance order (linear time, unlike `sum(lists, [])`).
                all_beams: list[BeamSearchSequence] = list(
                    itertools.chain.from_iterable(
                        instance.beams for instance in instances_batch
                    )
                )
                # pos[k]:pos[k + 1] is the slice of `all_beams` (and of the
                # per-beam outputs) belonging to the k-th instance.
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch
                    )
                )
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:])
                )

                if len(all_beams) == 0:
                    break

                # only runs for one step
                # we don't need to use tqdm here
                output = self._render_and_run_requests(
                    prompts=(beam.get_prompt() for beam in all_beams),
                    params=self._params_to_seq(sampling_params, len(all_beams)),
                    output_type=RequestOutput,
                    lora_requests=[beam.lora_request for beam in all_beams],
                    use_tqdm=False,
                )

                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    current_beam.orig_prompt,
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs + [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                )

                                if token_id == eos_token_id and not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
                    instance.beams = sorted_beams[:beam_width]

        outputs: list[BeamSearchOutput] = []
        for instance in instances:
            # Beams still alive when the token budget ran out are candidates
            # too, alongside the ones that already hit EOS.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)

            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

828
    def _preprocess_cmpl(
        self,
        prompts: Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Turn prompt inputs given to the non-chat LLM APIs (everything except
        [LLM.chat][]) into a form that `_add_request` accepts.

        See [LLM.generate][] for a full description of the arguments.

        Args:
            prompts: The prompts to render.
            tokenization_kwargs: Optional overrides applied on top of the
                renderer's default completion tokenization parameters.

        Returns:
            A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
        """
        renderer = self.renderer
        model_config = self.model_config

        # Parse every raw prompt against the model config first.
        parsed_prompts = [parse_model_prompt(model_config, p) for p in prompts]

        overrides = tokenization_kwargs or {}
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(**overrides)

        return renderer.render_cmpl(parsed_prompts, tok_params)
853

854
855
856
857
858
859
860
861
    def _preprocess_cmpl_one(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """Render a single completion prompt through `_preprocess_cmpl`.

        The batched helper must produce exactly one result for the one-item
        batch; the tuple unpacking raises `ValueError` otherwise.
        """
        rendered = self._preprocess_cmpl([prompt], tokenization_kwargs)
        (only_prompt,) = rendered
        return only_prompt

862
863
    def _preprocess_chat(
        self,
        conversations: Sequence[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Render chat conversations into prompts usable by the other LLM APIs.

        See [LLM.chat][] for a full description of the arguments.

        Returns:
            A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
        """
        renderer = self.renderer

        # Template-level options are merged with any user-provided
        # `chat_template_kwargs`. `tokenize` is requested only when the
        # renderer's tokenizer is a MistralTokenizer.
        template_overrides = {
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            "tokenize": isinstance(renderer.tokenizer, MistralTokenizer),
        }
        merged_template_kwargs = merge_kwargs(
            chat_template_kwargs, template_overrides
        )
        chat_params = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=merged_template_kwargs,
        )

        tok_params = renderer.default_chat_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

        extras = {"mm_processor_kwargs": mm_processor_kwargs}
        _, engine_prompts = renderer.render_chat(
            conversations, chat_params, tok_params, prompt_extras=extras
        )
        return engine_prompts
910

911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
    def _preprocess_chat_one(
        self,
        conversation: list[ChatCompletionMessageParam],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """Render one conversation via `_preprocess_chat`.

        Every option is forwarded unchanged. The batched helper must return
        exactly one prompt; the tuple unpacking raises `ValueError` otherwise.
        """
        batch = self._preprocess_chat(
            [conversation],
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        (single_prompt,) = batch
        return single_prompt

937
938
    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the [generate][vllm.LLM.generate] method to generate
        the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A sequence of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "auto" (the default) chooses the format automatically.
                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, the prompt that starts a new
                assistant turn is appended to the end of each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`, so pass
                `add_generation_prompt=False` together with it.
            tools: Optional tool (function) definitions passed through to the
                chat template.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            mm_processor_kwargs: Overrides for `processor.__call__`.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.

        Raises:
            ValueError: If the model is not run as a generative model
                (`runner_type != "generate"`).
        """
        runner_type = self.model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.chat() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        return self._run_chat(
            messages=messages,
            params=sampling_params,
            output_type=RequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

1031
1032
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict
                containing a `"data"` key is instead handled by the installed
                IOProcessor plugin (parsed, pre-processed and post-processed
                by it).
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Deprecated. Pass it via
                `tokenization_kwargs={"truncate_prompt_tokens": ...}` instead.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run. Required: a `ValueError`
                is raised when it is None. Each `PoolingParams.task` that is
                unset is filled with this value.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For IOProcessor prompts, a single-element list wrapping the
            plugin's post-processed output.

        Raises:
            ValueError: If `pooling_task` is None, the model is not run as a
                pooling model, an IOProcessor prompt is given without a
                plugin installed or with `data=None`, or a provided
                `PoolingParams.task` conflicts with `pooling_task`.
        """
        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if truncate_prompt_tokens is not None:
            warnings.warn(
                "The `truncate_prompt_tokens` parameter in `LLM.encode()` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        # A dict prompt carrying "data" is plugin input, not a model prompt.
        if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            prompt_data = prompts.get("data")
            if prompt_data is None:
                raise ValueError(
                    "The 'data' field of the prompt is expected to contain "
                    "the prompt data and it cannot be None. "
                    "Refer to the documentation of the IOProcessor "
                    "in use for more details."
                )
            validated_prompt = self.io_processor.parse_data(prompt_data)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)
            prompts_seq = prompt_to_seq(prompts)

            params_seq: Sequence[PoolingParams] = [
                self.io_processor.merge_pooling_params(param)
                for param in self._params_to_seq(
                    pooling_params,
                    len(prompts_seq),
                )
            ]
            for p in params_seq:
                if p.task is None:
                    p.task = "plugin"
        else:
            if pooling_params is None:
                # Use default pooling params.
                pooling_params = PoolingParams()

            prompts_seq = prompt_to_seq(prompts)
            params_seq = self._params_to_seq(pooling_params, len(prompts_seq))

            # NOTE(review): `task` is assigned in place; unless
            # `_params_to_seq` copies its input, caller-supplied PoolingParams
            # objects are mutated here — TODO confirm.
            for param in params_seq:
                if param.task is None:
                    param.task = pooling_task
                elif param.task != pooling_task:
                    msg = (
                        f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                    )
                    raise ValueError(msg)

        outputs = self._run_completion(
            prompts=prompts_seq,
            params=params_seq,
            output_type=PoolingRequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        if use_io_processor:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(outputs)

            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
                    prompt_token_ids=[],
                    finished=True,
                )
            ]

        return outputs
1186

1187
1188
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If not None, merged into
                `tokenization_kwargs` as `truncate_prompt_tokens`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"embed"` task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        # Merge here rather than forwarding to `encode`, which would emit a
        # deprecation warning for its own `truncate_prompt_tokens` argument.
        if truncate_prompt_tokens is not None:
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"classify"` task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1294
1295
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: If not None, merged into
                `tokenization_kwargs` as `truncate_prompt_tokens`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        # Fold the value into `tokenization_kwargs` here, the same way `embed`
        # does. Forwarding it to `encode` would emit a DeprecationWarning that
        # names `LLM.encode()` and whose stacklevel points inside this file
        # rather than at the caller of `reward`.
        if truncate_prompt_tokens is not None:
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_classify",
            tokenization_kwargs=tokenization_kwargs,
        )

1336
1337
    def _embedding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Score pairs by cosine similarity of their embeddings.

        All texts of `data_1` followed by all texts of `data_2` are embedded
        in a single `encode` call. When `data_1` holds exactly one item, its
        embedding is paired with every item of `data_2`.

        Raises:
            NotImplementedError: If any item is not a plain string.
        """
        tokenizer = self.get_tokenizer()

        combined = [*data_1, *data_2]
        if any(not isinstance(item, str) for item in combined):
            raise NotImplementedError(
                "Embedding scores currently do not support multimodal input."
            )
        input_texts: list[str] = combined

        embeddings = self.encode(
            input_texts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        split_at = len(data_1)
        left = embeddings[:split_at]
        right = embeddings[split_at:]

        # Broadcast a lone left-hand embedding across all right-hand items.
        if len(left) == 1:
            left = left * len(right)

        similarities = _cosine_similarity(
            tokenizer=tokenizer,
            embed_1=left,
            embed_2=right,
        )

        return [ScoringRequestOutput.from_base(item) for item in similarities]
1378

1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
    def _late_interaction_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Late interaction scoring (ColBERT MaxSim).

        Queries (`data_1`) and documents (`data_2`) are converted to prompts
        and encoded together with `pooling_task="token_embed"`. Each
        query/document pair is then scored with `compute_maxsim_score`. If
        there is exactly one query, it is paired with every document. The
        per-pair token ids are the query ids, followed by the tokenizer's pad
        token (if it has one), followed by the document ids.
        """
        from vllm.outputs import PoolingOutput

        tokenizer = self.get_tokenizer()

        # score_data_to_prompts handles both text and multimodal ScoreData.
        model_config = self.model_config
        query_prompts = score_data_to_prompts(data_1, "query", model_config)
        doc_prompts = score_data_to_prompts(data_2, "document", model_config)

        encoded: list[PoolingRequestOutput] = self.encode(
            query_prompts + doc_prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        n_queries = len(query_prompts)
        query_outputs: list[PoolingRequestOutput] = encoded[:n_queries]
        doc_outputs: list[PoolingRequestOutput] = encoded[n_queries:]
        if len(query_outputs) == 1:
            query_outputs = query_outputs * len(doc_outputs)

        pad_token_id = tokenizer.pad_token_id
        separator: list[int] = [] if pad_token_id is None else [pad_token_id]

        pair_outputs: list[PoolingRequestOutput] = []
        for query_out, doc_out in zip(query_outputs, doc_outputs):
            # Per-token embedding matrices, e.g. [query_len, dim] and
            # [doc_len, dim].
            score_value = compute_maxsim_score(
                query_out.outputs.data, doc_out.outputs.data
            )
            joined_ids = (
                query_out.prompt_token_ids + separator + doc_out.prompt_token_ids
            )
            pair_outputs.append(
                PoolingRequestOutput(
                    request_id=f"{query_out.request_id}_{doc_out.request_id}",
                    outputs=PoolingOutput(data=score_value),
                    prompt_token_ids=joined_ids,
                    num_cached_tokens=(
                        query_out.num_cached_tokens + doc_out.num_cached_tokens
                    ),
                    finished=True,
                )
            )

        return [ScoringRequestOutput.from_base(item) for item in pair_outputs]

    def _cross_encoding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
        score_template: str | None,
    ) -> list[ScoringRequestOutput]:
        """Score each `(data_1[i], data_2[i])` pair with a cross-encoder.

        Each pair is rendered into one engine prompt by `get_score_prompt`,
        and all prompts are run through `_run_completion`. If `data_1` has a
        single item, it is paired with every item of `data_2`.

        Args:
            data_1: Query side of each pair (or a single query).
            data_2: Document side of each pair.
            use_tqdm: Progress-bar control forwarded to `_run_completion`.
            pooling_params: Base pooling parameters. If None, a fresh
                `PoolingParams(task="score")` is used. A caller-supplied
                object is never modified in place.
            lora_request: LoRA request(s) forwarded to `_run_completion`.
            tokenization_kwargs: Tokenizer overrides passed to
                `get_score_prompt`.
            score_template: Optional template passed to `get_score_prompt`.

        Returns:
            One `ScoringRequestOutput` per pair, in input order.

        Raises:
            ValueError: If the tokenizer is a `MistralTokenizer`.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            # 1 -> N scoring: pair the single query with every document.
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        elif pooling_params.task is None:
            # Set the task on a copy. Mutating the caller's object would leak
            # task="score" into later calls that reuse the same params.
            pooling_params = pooling_params.clone()
            pooling_params.task = "score"

        prompts: list[PromptType] = []
        pooling_params_list: list[PoolingParams] = []

        for query, document in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=query,
                data_2=document,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            # token_type_ids are not sent to the engine as part of the prompt.
            # When present, they travel in compressed form via a per-request
            # copy of the pooling params.
            if token_type_ids := engine_prompt.pop("token_type_ids", None):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        outputs = self._run_completion(
            prompts=prompts,
            params=pooling_params_list,
            output_type=PoolingRequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        return [ScoringRequestOutput.from_base(item) for item in outputs]

    def score(
        self,
        data_1: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        data_2: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
        `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`. In the `1 -> N`
        case the `data_1` input is replicated `N` times to pair with the
        `data_2` inputs.

        The scoring strategy depends on the model:

        - cross-encoder models score each rendered pair directly;
        - late-interaction models (e.g. ColBERT) compare per-token
          embeddings with MaxSim;
        - all other pooling models compare pooled embeddings.

        This class automatically batches the prompts, considering the memory
        constraint. For the best performance, put all of your inputs into a
        single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: The chat template to use for the scoring. Only
                accepted for cross-encoder models. If None, we use the
                model's default chat template.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, supports neither
                scoring route, is a cross-encoder with `num_labels != 1`, or
                if `chat_template` is given for a non-cross-encoder model.
        """
        config = self.model_config

        if config.runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        tasks = self.supported_tasks
        # Late interaction models (e.g., ColBERT) score through token_embed,
        # so they do not need an "embed" or "classify" task.
        late_interaction = config.is_late_interaction
        if not late_interaction and not ("embed" in tasks or "classify" in tasks):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        cross_encoder = config.is_cross_encoder
        if cross_encoder and getattr(config.hf_config, "num_labels", 0) != 1:
            raise ValueError("Score API is only enabled for num_labels == 1.")
        if chat_template is not None and not cross_encoder:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        score_data_1, score_data_2 = validate_score_input(
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            is_multimodal_model=config.is_multimodal_model,
            architecture=config.architecture,
        )

        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )
        # Keyword arguments shared by every scoring route.
        shared_kwargs: dict[str, Any] = dict(
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=tok_params.get_encode_kwargs(),
        )

        if cross_encoder:
            return self._cross_encoding_score(
                score_data_1,
                score_data_2,
                score_template=chat_template,
                **shared_kwargs,
            )
        if late_interaction:
            return self._late_interaction_score(
                score_data_1, score_data_2, **shared_kwargs
            )
        return self._embedding_score(score_data_1, score_data_2, **shared_kwargs)

    def start_profile(self, profile_prefix: str | None = None) -> None:
        """Start profiling with optional custom trace prefix.

        Args:
            profile_prefix: Optional prefix for the trace file names. If provided,
                           trace files will be named as "<prefix>_dp<X>_pp<Y>_tp<Z>".
                           If not provided, default naming will be used.
        """
        self.llm_engine.start_profile(profile_prefix)
1651
1652
1653
1654

    def stop_profile(self) -> None:
        """End the profiling session on the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Ask the engine to reset its prefix cache.

        Args:
            reset_running_requests: Forwarded positionally to the engine.
            reset_connector: Forwarded positionally to the engine.

        Returns:
            Whatever the engine's `reset_prefix_cache` returns.
        """
        engine = self.llm_engine
        outcome = engine.reset_prefix_cache(reset_running_requests, reset_connector)
        return outcome

    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level.
                - Level 0: Pause scheduling but continue accepting requests.
                           Requests are queued but not processed.
                - Level 1: Offload model weights to CPU, discard KV cache.
                           The content of kv cache is forgotten. Good for
                           sleeping and waking up the engine to run the same
                           model again. Please make sure there's enough CPU
                           memory to store the model weights.
                - Level 2: Discard all GPU memory (weights + KV cache).
                           Good for sleeping and waking up the engine to run
                           a different model or update the model, where
                           previous model weights are not needed. It reduces
                           CPU memory pressure.
        """
        # Levels above 0 discard the KV cache, so the prefix cache is reset
        # first. Level 0 only pauses scheduling and keeps the cache.
        discards_kv_cache = level > 0
        if discards_kv_cache:
            self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

    def wake_up(self, tags: list[str] | None = None):
1688
        """
1689
1690
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1691

1692
        Args:
1693
1694
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1695
1696
1697
1698
                `("weights", "kv_cache", "scheduling")`. If None, all memory
                is reallocated. wake_up should be called with all tags
                (or None) before the engine is used again.
                Use tags=["scheduling"] to resume from level 0 sleep.
1699
1700
        """
        self.llm_engine.wake_up(tags)
1701

1702
1703
1704
1705
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            The metrics reported by the engine's `get_metrics()`, as a list
            of `Metric` objects capturing the current state of all
            aggregated metrics.

        Note:
            This method is only available with the V1 LLM engine.
        """
        engine = self.llm_engine
        snapshot = engine.get_metrics()
        return snapshot

    def _params_to_seq(
        self,
        params: _P | Sequence[_P],
        num_requests: int,
    ) -> Sequence[_P]:
        """Expand `params` into one entry per request.

        Args:
            params: Either a single params object, which is repeated for
                every request, or a sequence with exactly one entry per
                request, which is returned unchanged.
            num_requests: Number of requests (prompts or conversations).

        Returns:
            A sequence of length `num_requests`.

        Raises:
            ValueError: If `params` is a sequence whose length differs from
                `num_requests`.
        """
        if isinstance(params, Sequence):
            if len(params) != num_requests:
                # Report the prompt count, not the params object itself, to
                # match the sibling *_to_seq helpers.
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and params ({len(params)}) must be the same."
                )
            return params

        return [params] * num_requests

    def _lora_request_to_seq(
        self,
        lora_request: LoRARequest | None | Sequence[LoRARequest | None],
        num_requests: int,
    ) -> Sequence[LoRARequest | None]:
        """Expand `lora_request` into one entry per request.

        A single request (or None) is repeated `num_requests` times. A
        sequence is returned unchanged after its length is checked.

        Raises:
            ValueError: If a sequence's length differs from `num_requests`.
        """
        if not isinstance(lora_request, Sequence):
            return [lora_request] * num_requests

        if len(lora_request) != num_requests:
            raise ValueError(
                f"The lengths of prompts ({num_requests}) "
                f"and lora_request ({len(lora_request)}) must be the same."
            )
        return lora_request

    def _priority_to_seq(
        self,
        priority: list[int] | None,
        num_requests: int,
    ) -> Sequence[int]:
1751
1752
1753
1754
1755
1756
1757
        if priority is not None:
            if len(priority) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and priority ({len(priority)}) must be the same."
                )

1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
            return priority

        return [0] * num_requests

    def _run_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ):
        """Render completion prompts and run them through the engine.

        Args:
            prompts: A single prompt or a sequence of prompts, normalized with
                `prompt_to_seq`.
            params: One params object shared by all prompts, or one per prompt.
            output_type: Expected type of each engine output.
            use_tqdm: Progress-bar control for rendering and execution.
            lora_request: One LoRA request for all prompts, or one per prompt.
            priority: Optional per-prompt priorities. Must have one entry per
                normalized prompt.
            tokenization_kwargs: Base tokenizer kwargs. Each request's
                `truncate_prompt_tokens` is merged in via `merge_kwargs`.

        Returns:
            The finished outputs produced by `_render_and_run_requests`.
        """
        seq_prompts = prompt_to_seq(prompts)
        # Size every per-request list by the normalized prompt count. `prompts`
        # itself may be a single str/dict prompt, whose len() is the number of
        # characters/keys rather than the number of requests.
        num_requests = len(seq_prompts)

        seq_params = self._params_to_seq(params, num_requests)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
        seq_tok_kwargs = [
            merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
            )
            for param in seq_params
        ]
        seq_priority = self._priority_to_seq(priority, num_requests)

        return self._render_and_run_requests(
            prompts=(
                self._preprocess_cmpl_one(prompt, tok_kwargs)
                for prompt, tok_kwargs in zip(
                    maybe_tqdm(
                        seq_prompts,
                        use_tqdm=use_tqdm,
                        desc="Rendering prompts",
                    ),
                    seq_tok_kwargs,
                )
            ),
            params=seq_params,
            output_type=output_type,
            use_tqdm=use_tqdm,
            lora_requests=seq_lora_requests,
            priorities=seq_priority,
        )

    def _run_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ):
        """Render chat conversations and run them through the engine.

        `messages` is normalized with `conversation_to_seq`. Per-request
        params and LoRA requests are expanded to one per conversation. Each
        conversation is rendered lazily by `_preprocess_chat_one` with the
        chat-template options given here.

        Returns:
            The finished outputs produced by `_render_and_run_requests`.
        """
        conversations = conversation_to_seq(messages)
        num_conversations = len(conversations)

        per_request_params = self._params_to_seq(params, num_conversations)
        per_request_loras = self._lora_request_to_seq(
            lora_request, num_conversations
        )
        per_request_tok_kwargs = [
            merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=p.truncate_prompt_tokens),
            )
            for p in per_request_params
        ]

        progress = maybe_tqdm(
            conversations,
            use_tqdm=use_tqdm,
            desc="Rendering conversations",
        )
        # A generator lets each conversation be rendered just before it is
        # added to the engine.
        rendered = (
            self._preprocess_chat_one(
                conversation,
                chat_template=chat_template,
                chat_template_content_format=chat_template_content_format,
                chat_template_kwargs=chat_template_kwargs,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                tokenization_kwargs=tok_kwargs,
                mm_processor_kwargs=mm_processor_kwargs,
            )
            for conversation, tok_kwargs in zip(progress, per_request_tok_kwargs)
        )

        return self._render_and_run_requests(
            prompts=rendered,
            params=per_request_params,
            output_type=output_type,
            lora_requests=per_request_loras,
            use_tqdm=use_tqdm,
        )

    def _render_and_run_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ):
        """Add every prompt to the engine, then run it to completion.

        A one-time warning is logged when `prompts` is an already materialized
        list or tuple rather than a lazy iterable.

        Returns:
            The outputs collected by `_run_engine`.
        """
        already_rendered = isinstance(prompts, (list, tuple))
        if already_rendered:
            logger.warning_once(
                "Rendering all prompts before adding them to the engine "
                "is less efficient than performing both on the same prompt "
                "before processing the next prompt. You should instead pass "
                "a generator that renders one prompt per iteration, as that allows "
                "engine execution to begin for the first prompt while processing "
                "the next prompt."
            )

        self._render_and_add_requests(
            prompts=prompts,
            params=params,
            lora_requests=lora_requests,
            priorities=priorities,
        )
        return self._run_engine(output_type, use_tqdm=use_tqdm)

    def _render_and_add_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
    ) -> list[str]:
        """Add each prompt to the engine, indexing the per-request sequences.

        If anything raises part-way through (including rendering inside a
        lazy `prompts` iterable), the requests added so far are aborted and
        the exception propagates.

        Returns:
            The values returned by `_add_request`, in prompt order.
        """
        added_request_ids: list[str] = []

        try:
            for idx, prompt in enumerate(prompts):
                request_params = params[idx]
                raw_lora = None if lora_requests is None else lora_requests[idx]
                resolved_lora = self._resolve_mm_lora(prompt, raw_lora)
                request_priority = 0 if priorities is None else priorities[idx]

                added_request_ids.append(
                    self._add_request(
                        prompt,
                        request_params,
                        lora_request=resolved_lora,
                        priority=request_priority,
                    )
                )
        except Exception:
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            raise

        return added_request_ids

    def _add_request(
        self,
        prompt: ProcessorInputs,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
    ) -> str:
        """Submit a single prompt to the engine under a fresh request ID.

        `SamplingParams` are switched to `RequestOutputKind.FINAL_ONLY` in
        place, because only the final output is collected.

        Returns:
            The value returned by `llm_engine.add_request`.
        """
        if isinstance(params, SamplingParams):
            params.output_kind = RequestOutputKind.FINAL_ONLY

        next_id = next(self.request_counter)
        return self.llm_engine.add_request(
            str(next_id),
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
        )

    def _run_engine(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]:
        """Step the engine until no unfinished requests remain.

        Args:
            output_type: Expected type(s) of every step output, checked with
                `assert isinstance`.
            use_tqdm: If truthy, a progress bar is shown. A callable is used
                as the bar factory; otherwise `tqdm` itself is used. For
                `RequestOutput` items, the bar's postfix shows estimated
                input/output token throughput.

        Returns:
            All finished outputs, sorted by `int(request_id)`.
        """
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            make_bar = use_tqdm if callable(use_tqdm) else tqdm
            pbar = make_bar(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        finished: list[_O] = []
        input_tokens = 0
        output_tokens = 0

        engine = self.llm_engine
        while engine.has_unfinished_requests():
            for output in engine.step():
                assert isinstance(output, output_type)
                if not output.finished:
                    continue

                finished.append(output)  # type: ignore[arg-type]
                if not use_tqdm:
                    continue

                if not isinstance(output, RequestOutput):
                    pbar.update(1)
                else:
                    # Token throughput is only tracked for RequestOutput.
                    num_seqs = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    input_tokens += len(output.prompt_token_ids) * num_seqs
                    input_speed = input_tokens / pbar.format_dict["elapsed"]
                    output_tokens += sum(
                        len(seq.token_ids) for seq in output.outputs
                    )
                    output_speed = output_tokens / pbar.format_dict["elapsed"]
                    pbar.postfix = (
                        f"est. speed input: {input_speed:.2f} toks/s, "
                        f"output: {output_speed:.2f} toks/s"
                    )
                    pbar.update(num_seqs)

                if pbar.n == num_requests:
                    pbar.refresh()

        if use_tqdm:
            pbar.close()

        # Outputs arrive in completion order, which can differ from request
        # order, so sort them by their numeric request ID.
        finished.sort(key=lambda out: int(out.request_id))
        return finished

    def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest | dict
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Args:
            request: Either a `WeightTransferInitRequest` or a plain dict.
                Its `init_info` (attribute or key) is forwarded to every
                worker through `collective_rpc`.
        """
        if isinstance(request, dict):
            init_info_dict = request["init_info"]
        else:
            init_info_dict = request.init_info

        rpc_kwargs = {"init_info": init_info_dict}
        self.llm_engine.collective_rpc(
            "init_weight_transfer_engine", kwargs=rpc_kwargs
        )

    def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
        """
        Update the weights of the model.

        Args:
            request: Either a `WeightTransferUpdateRequest` or a plain dict.
                Its `update_info` (attribute or key) is forwarded to every
                worker through `collective_rpc`.
        """
        if isinstance(request, dict):
            update_info_dict = request["update_info"]
        else:
            update_info_dict = request.update_info

        rpc_kwargs = {"update_info": update_info_dict}
        self.llm_engine.collective_rpc("update_weights", kwargs=rpc_kwargs)

    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model.

        The text comes from the first worker's `get_model_inspection` result
        and is cached on the instance, so `collective_rpc` runs at most once
        per successful lookup. If no worker results are returned, the result
        falls back to `LLM(model=...)`.
        """
        if self._cached_repr is not None:
            return self._cached_repr

        # Every worker is expected to report the same structure; use the first.
        results = self.llm_engine.collective_rpc("get_model_inspection")
        self._cached_repr = (
            results[0] if results else f"LLM(model={self.model_config.model!r})"
        )
        return self._cached_repr