llm.py 84.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Callable, Iterable, Sequence
6
from pathlib import Path
7
from typing import TYPE_CHECKING, Any
8

9
import cloudpickle
10
import torch.nn as nn
11
from pydantic import ValidationError
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar, overload
14

15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
22
    AttentionConfig,
23
    CompilationConfig,
24
    PoolerConfig,
25
    ProfilerConfig,
26
27
28
    StructuredOutputsConfig,
    is_init_field,
)
29
from vllm.config.compilation import CompilationMode
30
from vllm.config.model import (
31
32
    ConvertOption,
    HfOverrides,
33
    ModelDType,
34
    RunnerOption,
35
    TokenizerMode,
36
)
37
38
39
40
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
41
from vllm.engine.arg_utils import EngineArgs
42
43
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
44
    ChatTemplateConfig,
45
    ChatTemplateContentFormatOption,
46
    load_chat_template,
47
)
48
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
49
from vllm.entrypoints.pooling.score.utils import (
50
    ScoreData,
51
52
53
    ScoreMultiModalParam,
    _cosine_similarity,
    compress_token_type_ids,
54
    compute_maxsim_score,
55
    get_score_prompt,
56
    score_data_to_prompts,
57
    validate_score_input,
58
)
59
from vllm.entrypoints.utils import log_non_default_args
60
from vllm.inputs.data import (
61
    DataPrompt,
62
    ProcessorInputs,
63
64
65
66
67
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
68
from vllm.logger import init_logger
69
from vllm.lora.request import LoRARequest
70
from vllm.model_executor.layers.quantization import QuantizationMethods
71
72
73
74
75
76
77
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
78
from vllm.platforms import current_platform
79
from vllm.pooling_params import PoolingParams
80
from vllm.renderers import ChatParams, merge_kwargs
81
82
83
84
85
from vllm.renderers.inputs.preprocess import (
    conversation_to_seq,
    parse_model_prompt,
    prompt_to_seq,
)
86
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
87
from vllm.tasks import PoolingTask
88
from vllm.tokenizers import TokenizerLike
yhu422's avatar
yhu422 committed
89
from vllm.usage.usage_lib import UsageContext
90
from vllm.utils.counter import Counter
91
from vllm.utils.mistral import is_mistral_tokenizer
92
from vllm.utils.tqdm_utils import maybe_tqdm
93
from vllm.v1.engine import PauseMode
94
from vllm.v1.engine.llm_engine import LLMEngine
95
from vllm.v1.sample.logits_processor import LogitsProcessor
96

97
98
99
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

100
101
logger = init_logger(__name__)

# Output type accepted by the typed `LLM.wait_for_completion` overload;
# defaults to "either a generation or a pooling result".
_O = TypeVar(
    "_O",
    bound=RequestOutput | PoolingRequestOutput,
    default=RequestOutput | PoolingRequestOutput,
)

# Per-request parameter type: sampling params, pooling params, or None.
_P = TypeVar(
    "_P",
    bound=SamplingParams | PoolingParams | None,
)

# Generic result type for worker callbacks (`collective_rpc`, `apply_model`)
# and for the config factory used in `LLM.__init__`; defaults to Any.
_R = TypeVar("_R", default=Any)

110
111

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allows API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
        allowed_media_domains: If set, only media URLs that belong to these
            domains can be used for multi-modal inputs.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        chat_template: The chat template to apply.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. By
            default, this is set to None and vLLM automatically infers the KV
            cache size based on gpu_memory_utilization. However, users may want
            to manually specify the KV cache memory size. kv_cache_memory_bytes
            allows more fine-grained control of how much memory gets used than
            gpu_memory_utilization does. Note that kv_cache_memory_bytes (when
            not None) ignores gpu_memory_utilization.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        offload_group_size: Prefetch offloading: Group every N layers
            together. Offload last `offload_num_in_group` layers of each group.
            Default is 0 (disabled).
        offload_num_in_group: Prefetch offloading: Number of layers to
            offload per group. Default is 1.
        offload_prefetch_step: Prefetch offloading: Number of layers to
            prefetch ahead. Higher values hide more latency but use more GPU
            memory. Default is 1.
        offload_params: Prefetch offloading: Set of parameter name segments
            to selectively offload. Only parameters whose names contain one of
            these segments will be offloaded (e.g., {"gate_up_proj", "down_proj"}
            for MLP weights, or {"w13_weight", "w2_weight"} for MoE expert
            weights). If None or empty, all parameters are offloaded.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        enable_return_routed_experts: Whether to return routed experts.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `hf auth login` (stored in `~/.cache/huggingface/token`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        pooler_config: Initialize non-default pooling config for the pooling model,
            e.g., `PoolerConfig(seq_pooling_type="MEAN", use_activation=False)`.
        compilation_config: Either an integer, a dictionary, or a
            CompilationConfig instance. If it is an integer, it is used as the
            mode of compilation optimization. If it is a dictionary, it can
            specify the full compilation configuration.
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """
216
217
218
219

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        chat_template: Path | str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        cpu_offload_gb: float = 0,
        offload_group_size: int = 0,
        offload_num_in_group: int = 1,
        offload_prefetch_step: int = 1,
        offload_params: set[str] | None = None,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """Create the engine plus the helpers used by the offline LLM API.

        Explicit arguments are documented on the class. Any other keyword
        arguments are forwarded to ``EngineArgs``. A few of them are
        normalized first:

        - ``swap_space`` is dropped with a ``DeprecationWarning``.
        - ``disable_log_stats`` defaults to True.
        - A class passed as ``worker_cls`` is serialized with cloudpickle.
        - A dict ``kv_transfer_config`` is converted to ``KVTransferConfig``.

        Raises:
            ValueError: If a dict ``kv_transfer_config`` fails validation, or
                if ``data_parallel_size > 1`` is requested in this process
                without the ``external_launcher`` executor backend on a
                non-TPU platform.
        """
        # ---- Normalize keyword arguments that are forwarded to EngineArgs ----
        if "swap_space" in kwargs:
            del kwargs["swap_space"]
            import warnings

            warnings.warn(
                "The 'swap_space' parameter is deprecated and ignored. "
                "It will be removed in a future version.",
                DeprecationWarning,
                stacklevel=2,
            )

        # Offline usage does not log stats unless the caller opts in.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class object (as opposed to a qualified-name string) is
        # serialized with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        raw_kv_transfer = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer, dict):
            from vllm.config.kv_transfer import KVTransferConfig

            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_kv_transfer)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        hf_overrides = {} if hf_overrides is None else hf_overrides

        # ---- Materialize config objects ---------------------------------------
        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Turn None / a dict / an existing instance into a config object.

            Dict keys that are not init fields of ``cls`` are dropped; any
            non-dict, non-None value is returned unchanged.
            """
            if isinstance(value, dict):
                init_kwargs = {
                    key: val for key, val in value.items() if is_init_field(cls, key)
                }
                return cls(**init_kwargs)  # type: ignore[arg-type]
            return cls() if value is None else value

        # An integer selects a compilation mode; anything else goes through
        # the generic conversion.
        compilation_config_instance = (
            CompilationConfig(mode=CompilationMode(compilation_config))
            if isinstance(compilation_config, int)
            else _make_config(compilation_config, CompilationConfig)
        )
        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # ---- Reject single-process data parallelism -----------------------------
        # DP > 1 inside one process may hang; it is only allowed when an
        # external launcher manages the processes or when running on TPU.
        dp_size = int(kwargs.get("data_parallel_size", 1))
        executor_backend = kwargs.get("distributed_executor_backend")
        if dp_size > 1:
            if (
                executor_backend != "external_launcher"
                and not current_platform.is_tpu()
            ):
                raise ValueError(
                    f"LLM(data_parallel_size={dp_size}) is not supported for single-"
                    "process usage and may hang. Please use "
                    "the explicit multi-process data-parallel example at "
                    "'examples/offline_inference/data_parallel.py'."
                )

        # ---- Build the engine ---------------------------------------------------
        resolved_offload_params = offload_params or set()
        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            cpu_offload_gb=cpu_offload_gb,
            offload_group_size=offload_group_size,
            offload_num_in_group=offload_num_in_group,
            offload_prefetch_step=offload_prefetch_step,
            offload_params=resolved_offload_params,
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        engine = self.llm_engine
        self.model_config = engine.model_config
        self.engine_class = type(engine)

        self.request_counter = Counter()
        # Lazily filled by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        self.supported_tasks = supported_tasks = engine.get_supported_tasks()
        self.pooling_task = self.model_config.get_pooling_task(supported_tasks)
        if self.pooling_task is not None:
            logger.info("Supported pooling task: %s", self.pooling_task)

        # ---- Rendering / IO helpers ---------------------------------------------
        self.runner_type = self.model_config.runner_type
        self.renderer = engine.renderer
        self.chat_template = load_chat_template(chat_template)
        self.io_processor = engine.io_processor
        self.input_processor = engine.input_processor
        self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
        self.pooling_io_processors = init_pooling_io_processors(
            supported_tasks=supported_tasks,
            model_config=self.model_config,
            renderer=self.renderer,
            chat_template_config=self.chat_template_config,
        )

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

412
413
414
415
416
    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLM":
        """Build an ``LLM`` whose constructor arguments mirror ``engine_args``.

        Every attribute stored on ``engine_args`` is passed as a keyword
        argument. Fields that are not explicit ``LLM`` parameters land in
        ``**kwargs`` and are forwarded back to ``EngineArgs`` by the
        constructor.
        """
        constructor_kwargs = vars(engine_args)
        return cls(**constructor_kwargs)

417
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer held by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
419

420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
    def get_world_size(self, include_dp: bool = True) -> int:
        """Return the number of workers described by the parallel config.

        Args:
            include_dp: When True (the default), count workers across all
                data-parallel replicas (TP * PP * DP). When False, count only
                a single replica (TP * PP).

        Returns:
            ``parallel_config.world_size_across_dp`` if ``include_dp`` is
            true, otherwise ``parallel_config.world_size``.
        """
        parallel_config = self.llm_engine.vllm_config.parallel_config
        return (
            parallel_config.world_size_across_dp
            if include_dp
            else parallel_config.world_size
        )

437
    def reset_mm_cache(self) -> None:
        """Clear the multi-modal caches of both the renderer and the engine."""
        renderer, engine = self.renderer, self.llm_engine
        renderer.clear_mm_cache()
        engine.reset_mm_cache()

441
    def get_default_sampling_params(self) -> SamplingParams:
        """Return a new ``SamplingParams`` built from the model's defaults.

        The model-provided overrides (``model_config.get_diff_sampling_param()``)
        are fetched once and cached on ``self.default_sampling_params``. When
        there are no overrides, a plain ``SamplingParams()`` is returned.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = self.model_config.get_diff_sampling_param()

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

448
449
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """Generate completions for one or more prompts.

        Prompts are batched automatically with respect to memory limits, so
        for the best performance pass all prompts in a single call.

        Args:
            prompts: A single prompt or a sequence of prompts. See
                [PromptType][vllm.inputs.PromptType] for the accepted formats.
            sampling_params: Sampling parameters for text generation.
                ``None`` means the model's default sampling parameters (see
                ``get_default_sampling_params``). A single value applies to
                every prompt. A sequence must have the same length as
                ``prompts`` and is paired with them one by one.
            use_tqdm: ``True`` shows a tqdm progress bar and ``False`` shows
                none. A callable (e.g. ``functools.partial(tqdm,
                leave=False)``) is used to create the progress bar.
            lora_request: LoRA request(s) to use for generation, if any.
            priority: Optional per-request priorities, only applicable when
                the priority scheduling policy is enabled. If provided, it
                must be a list of integers with the same length as
                ``prompts``, where each value applies to the prompt at the
                same index.
            tokenization_kwargs: Overrides for ``tokenizer.encode``.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            completions, in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a generative model.
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        effective_params = (
            self.get_default_sampling_params()
            if sampling_params is None
            else sampling_params
        )

        return self._run_completion(
            prompts=prompts,
            params=effective_params,
            output_type=RequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
509

510
511
512
513
    def enqueue(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Enqueue prompts for generation without waiting for completion.

        This method adds requests to the engine queue but does not start
        processing them. Use ``wait_for_completion()`` to process the queued
        requests and get their results.

        Args:
            prompts: The prompts to the LLM. See ``generate()`` for details.
            sampling_params: The sampling parameters for text generation.
                ``None`` means the model's default sampling parameters (see
                ``get_default_sampling_params``).
            lora_request: LoRA request(s) to use for generation, if any.
            priority: The priority of the requests, if any. See
                ``generate()`` for the expected format.
            use_tqdm: Progress-bar option used while adding requests. It
                takes the same kinds of values as in ``generate()``.
            tokenization_kwargs: Overrides for ``tokenizer.encode``.

        Returns:
            A list of request IDs for the enqueued requests.

        Raises:
            ValueError: If the model is not run as a generative model.
        """
        if self.model_config.runner_type != "generate":
            # Same actionable hint as generate(), so both entry points give
            # the user the same way to fix the problem.
            raise ValueError(
                "LLM.enqueue() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        return self._add_completion_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
        )

552
    @overload
    def wait_for_completion(
        self,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[RequestOutput | PoolingRequestOutput]: ...

    @overload
    def wait_for_completion(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]: ...

    def wait_for_completion(
        self,
        output_type: type[Any] | tuple[type[Any], ...] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[Any]:
        """Wait for all enqueued requests to complete and return results.

        This method processes all requests currently in the engine queue
        and returns their outputs. Use it after ``enqueue()`` to get results.

        Args:
            output_type: The expected output type, or a tuple of accepted
                output types. Defaults to
                ``(RequestOutput, PoolingRequestOutput)``, so both generation
                and pooling outputs are accepted when it is omitted.
            use_tqdm: Progress-bar option. It takes the same kinds of values
                as in ``generate()``.

        Returns:
            A list of output objects for all completed requests.
        """
        if output_type is None:
            # Keep in sync with the return type of the first overload above.
            output_type = (RequestOutput, PoolingRequestOutput)

        return self._run_engine(output_type, use_tqdm=use_tqdm)
589

Cyrus Leung's avatar
Cyrus Leung committed
590
    def _resolve_mm_lora(
        self,
        prompt: ProcessorInputs,
        lora_request: LoRARequest | None,
    ) -> LoRARequest | None:
        """Choose the LoRA adapter to use for one processed prompt.

        ``lora_request`` is returned unchanged in these cases:

        - The prompt is not multimodal.
        - No ``default_mm_loras`` are configured.
        - None of the prompt's modalities has a default LoRA.
        - Several of them do. A warning is logged, because only one LoRA per
          request is supported.

        When exactly one modality matches, an explicitly provided
        ``lora_request`` still wins. A warning is logged if its ID differs
        from the modality LoRA's ID. Otherwise a ``LoRARequest`` for the
        modality's default LoRA is built.

        The ID of a default modality LoRA is the 1-based index of its modality
        name among the registered modalities sorted alphabetically.
        """
        if prompt["type"] != "multimodal":
            return lora_request

        lora_config = self.llm_engine.vllm_config.lora_config
        mm_loras = lora_config.default_mm_loras if lora_config is not None else None
        if not mm_loras:
            return lora_request

        matched = set(prompt["mm_placeholders"].keys()).intersection(mm_loras.keys())
        if not matched:
            return lora_request

        if len(matched) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be "
                "used by a single prompt consuming several modalities; "
                "currently we only support one lora per request; as such, "
                "lora(s) registered with modalities: %s will be skipped",
                matched,
            )
            return lora_request

        (modality_name,) = matched
        modality_lora_id = sorted(mm_loras).index(modality_name) + 1

        # An explicitly provided request always takes precedence; only warn
        # when it disagrees with the modality's default LoRA.
        if lora_request and lora_request.lora_int_id != modality_lora_id:
            logger.warning(
                "A modality with a registered lora and a lora_request "
                "with a different ID were provided; falling back to the "
                "lora_request as we only apply one LoRARequest per prompt"
            )
        if lora_request:
            return lora_request

        return LoRARequest(modality_name, modality_lora_id, mm_loras[modality_name])

642
643
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Call a method on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and executed on each worker. A callable must accept
                an additional ``self`` argument (the worker object) besides
                the ones supplied through ``args`` and ``kwargs``.
            timeout: Maximum number of seconds to wait. A [`TimeoutError`][]
                is raised on timeout. ``None`` waits indefinitely.
            args: Positional arguments for the worker method.
            kwargs: Keyword arguments for the worker method.

        Returns:
            A list with one result per worker.

        Note:
            Prefer using this API only for control messages, and set up
            separate data-plane communication to pass data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
673
674

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.

        Args:
            func: Callable invoked with the worker's model module.

        Returns:
            A list with the value returned by `func` on each worker.

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
        """
        return self.llm_engine.apply_model(func)

    def beam_search(
        self,
        prompts: list[TokensPrompt | TextPrompt],
        params: BeamSearchParams,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        use_tqdm: bool = False,
        concurrency_limit: int | None = None,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            use_tqdm: Whether to use tqdm to display the progress bar.
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.

        Returns:
            One `BeamSearchOutput` per prompt, in input order, holding at most
            `params.beam_width` sequences ranked best-first.

        Raises:
            ValueError: If `concurrency_limit` is smaller than 1.
            NotImplementedError: If a prompt is given as embeddings.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        engine_prompts = self._preprocess_cmpl(prompts)
        lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts))

        # A non-positive batch size would either crash `range` (step 0) or
        # silently skip decoding altogether (negative step).
        if concurrency_limit is not None and concurrency_limit < 1:
            raise ValueError(
                "concurrency_limit must be a positive integer, "
                f"got {concurrency_limit}."
            )

        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar."
            )
            use_tqdm = False

        if concurrency_limit is None:
            # Keep the step positive even for an empty prompt list, otherwise
            # `range(0, 0, 0)` below raises instead of returning `[]`.
            concurrency_limit = max(len(engine_prompts), 1)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        sampling_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
        )
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, engine_prompts):
            if prompt["type"] == "embeds":
                raise NotImplementedError(
                    "Embedding prompt not supported for beam search"
                )

            instances.append(
                BeamSearchInstance(
                    prompt,
                    lora_request=lora_req,
                    logprobs=None,
                ),
            )

        for prompt_start in range(0, len(instances), concurrency_limit):
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress."
                )

            for _ in token_iter:
                # Flatten the live beams of every instance in the batch (linear
                # time, unlike `sum(lists, [])`); the prefix sums record which
                # slice of `all_beams` / `output` belongs to each instance.
                all_beams: list[BeamSearchSequence] = list(
                    itertools.chain.from_iterable(
                        instance.beams for instance in instances_batch
                    )
                )
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch
                    )
                )
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:])
                )

                if not all_beams:
                    break

                # only runs for one step
                # we don't need to use tqdm here
                output = self._render_and_run_requests(
                    prompts=(beam.get_prompt() for beam in all_beams),
                    params=self._params_to_seq(sampling_params, len(all_beams)),
                    output_type=RequestOutput,
                    lora_requests=[beam.lora_request for beam in all_beams],
                    use_tqdm=False,
                )

                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    current_beam.orig_prompt,
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs + [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                )

                                if token_id == eos_token_id and not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
                    instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after the last step compete with the beams
            # that already hit EOS for the final ranking.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)

            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def _preprocess_cmpl(
        self,
        prompts: Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
        a format that can be passed to `_add_request`.

        Refer to [LLM.generate][] for a complete description of the arguments.

        Returns:
            A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
        """
        renderer = self.renderer
        model_config = self.model_config

        parsed_prompts = [
            parse_model_prompt(model_config, prompt) for prompt in prompts
        ]
        # Per-call tokenization overrides (if any) are applied to the
        # renderer's default completion tokenization parameters.
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

        return renderer.render_cmpl(parsed_prompts, tok_params)

    def _preprocess_cmpl_one(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        (engine_prompt,) = self._preprocess_cmpl([prompt], tokenization_kwargs)
        return engine_prompt

878
879
    def _preprocess_chat(
        self,
880
        conversations: Sequence[list[ChatCompletionMessageParam]],
881
        chat_template: str | None = None,
882
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
883
        chat_template_kwargs: dict[str, Any] | None = None,
884
        add_generation_prompt: bool = True,
885
        continue_final_message: bool = False,
886
        tools: list[dict[str, Any]] | None = None,
887
        tokenization_kwargs: dict[str, Any] | None = None,
888
        mm_processor_kwargs: dict[str, Any] | None = None,
889
    ) -> Sequence[ProcessorInputs]:
nunjunj's avatar
nunjunj committed
890
        """
891
892
893
894
        Convert a list of conversations into prompts so that they can then
        be used as input for other LLM APIs.

        Refer to [LLM.chat][] for a complete description of the arguments.
nunjunj's avatar
nunjunj committed
895
896

        Returns:
897
            A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
nunjunj's avatar
nunjunj committed
898
        """
899
        renderer = self.renderer
900

901
902
903
904
905
906
907
908
909
        chat_params = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=merge_kwargs(
                chat_template_kwargs,
                dict(
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
910
                    tokenize=is_mistral_tokenizer(renderer.tokenizer),
911
912
913
                ),
            ),
        )
914
915
916
        tok_params = renderer.default_chat_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )
917

918
919
920
921
922
923
        _, engine_prompts = renderer.render_chat(
            conversations,
            chat_params,
            tok_params,
            prompt_extras={"mm_processor_kwargs": mm_processor_kwargs},
        )
924

925
        return engine_prompts
926

927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
    def _preprocess_chat_one(
        self,
        conversation: list[ChatCompletionMessageParam],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        (engine_prompt,) = self._preprocess_chat(
            [conversation],
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        return engine_prompt

953
954
    def chat(
        self,
955
        messages: list[ChatCompletionMessageParam]
956
957
        | Sequence[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
958
        use_tqdm: bool | Callable[..., tqdm] = True,
959
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
960
        chat_template: str | None = None,
961
962
963
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
964
965
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
966
        tokenization_kwargs: dict[str, Any] | None = None,
967
        mm_processor_kwargs: dict[str, Any] | None = None,
968
969
970
971
972
973
974
975
976
977
978
979
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the [generate][vllm.LLM.generate] method to generate
        the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
980
            messages: A sequence of conversations or a single conversation.
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
1012
1013
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            mm_processor_kwargs: Overrides for `processor.__call__`.
1014
1015
1016
1017
1018

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.chat() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

1031
        return self._run_chat(
1032
1033
            messages=messages,
            params=sampling_params,
1034
            output_type=RequestOutput,
1035
1036
            use_tqdm=use_tqdm,
            lora_request=lora_request,
1037
1038
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
1039
            chat_template_kwargs=chat_template_kwargs,
1040
1041
1042
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
1043
            tokenization_kwargs=tokenization_kwargs,
1044
1045
1046
            mm_processor_kwargs=mm_processor_kwargs,
        )

1047
1048
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run. It must be given;
                `None` is rejected by `_verify_pooling_task`.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the pooling task is missing or not supported, if a
                `DataPrompt` is passed without an IOProcessor plugin or with
                `data=None`, or if a pooling parameter already carries a task
                different from `pooling_task`.
        """
        self._verify_pooling_task(pooling_task)

        if isinstance(prompts, dict) and "data" in prompts:
            # IOProcessor plugin path: the plugin turns the raw `data` payload
            # into model prompts and post-processes the pooled outputs.
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            prompt_data = prompts.get("data")
            if prompt_data is None:
                raise ValueError(
                    "The 'data' field of the prompt is expected to contain "
                    "the prompt data and it cannot be None. "
                    "Refer to the documentation of the IOProcessor "
                    "in use for more details."
                )
            validated_prompt = self.io_processor.parse_data(prompt_data)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)
            prompts_seq = prompt_to_seq(prompts)

            params_seq: Sequence[PoolingParams] = [
                self.io_processor.merge_pooling_params(param)
                for param in self._params_to_seq(
                    pooling_params,
                    len(prompts_seq),
                )
            ]
            # Params without an explicit task are routed to the plugin task.
            for p in params_seq:
                if p.task is None:
                    p.task = "plugin"

            outputs = self._run_completion(
                prompts=prompts_seq,
                params=params_seq,
                output_type=PoolingRequestOutput,
                use_tqdm=use_tqdm,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )

            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(outputs)

            # The post-processed result is wrapped in a single synthetic
            # output object rather than one output per model prompt.
            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
                    prompt_token_ids=[],
                    finished=True,
                )
            ]
        else:
            if pooling_params is None:
                # Use default pooling params.
                pooling_params = PoolingParams()

            prompts_seq = prompt_to_seq(prompts)
            params_seq = self._params_to_seq(pooling_params, len(prompts_seq))

            # NOTE: this assigns `task` on the params objects in place, which
            # may be objects supplied by the caller.
            for param in params_seq:
                if param.task is None:
                    param.task = pooling_task
                elif param.task != pooling_task:
                    msg = (
                        f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                    )
                    raise ValueError(msg)

            if pooling_task in self.pooling_io_processors:
                io_processor = self.pooling_io_processors[pooling_task]
                processor_inputs = io_processor.pre_process_offline(
                    prompts_seq, tokenization_kwargs
                )
                seq_lora_requests = self._lora_request_to_seq(
                    lora_request, len(prompts_seq)
                )
                # Size by the normalized sequence: `len(prompts)` would be the
                # character count of a single string prompt (or the key count
                # of a single dict prompt), not the number of requests.
                seq_priority = self._priority_to_seq(None, len(prompts_seq))

                self._render_and_add_requests(
                    prompts=processor_inputs,
                    params=params_seq,
                    lora_requests=seq_lora_requests,
                    priorities=seq_priority,
                )

                outputs = self._run_engine(
                    use_tqdm=use_tqdm, output_type=PoolingRequestOutput
                )
                outputs = io_processor.post_process_offline(outputs)
            else:
                outputs = self._run_completion(
                    prompts=prompts_seq,
                    params=params_seq,
                    output_type=PoolingRequestOutput,
                    use_tqdm=use_tqdm,
                    lora_request=lora_request,
                    tokenization_kwargs=tokenization_kwargs,
                )
        return outputs

    def _verify_pooling_task(self, pooling_task: PoolingTask | None):
        if self.runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        if (
            pooling_task in ("embed", "token_embed")
            and pooling_task not in self.supported_tasks
        ):
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        if (
            pooling_task in ("classify", "token_classify")
            and pooling_task not in self.supported_tasks
        ):
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        # plugin task uses io_processor.parse_request to verify inputs
        if pooling_task != "plugin" and pooling_task != self.pooling_task:
            if pooling_task not in self.supported_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_task!r} "
                    f"Supported tasks: {self.supported_tasks}"
                )
            else:
                logger.warning_once(
                    "Pooling multitask support is deprecated and will "
                    "be removed in v0.20. When the default pooling task is "
                    "not what you want, you need to manually specify it "
                    'via PoolerConfig(task="%s"). ',
                    pooling_task,
                )

1252
1253
    def embed(
        self,
1254
        prompts: PromptType | Sequence[PromptType],
1255
        *,
1256
1257
1258
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
1259
        tokenization_kwargs: dict[str, Any] | None = None,
1260
    ) -> list[EmbeddingRequestOutput]:
1261
1262
1263
1264
1265
1266
1267
1268
1269
        """
        Generate an embedding vector for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
1270
                for batch inference. See [PromptType][vllm.inputs.PromptType]
1271
                for more details about the format of each prompt.
1272
1273
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
1274
1275
1276
1277
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
1278
            lora_request: LoRA request to use for generation, if any.
1279
            tokenization_kwargs: Overrides for `tokenizer.encode`.
1280
1281

        Returns:
1282
            A list of `EmbeddingRequestOutput` objects containing the
1283
1284
1285
            embedding vectors in the same order as the input prompts.
        """

1286
1287
1288
1289
1290
1291
        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
1292
            tokenization_kwargs=tokenization_kwargs,
1293
        )
1294
1295
1296
1297
1298

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: Propagated from `encode`, e.g. when the model does not
                support the "classify" task.
        """
        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Implemented as `encode` with `pooling_task="token_classify"`.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: Propagated from `encode`, e.g. when the model does not
                support the "token_classify" task.
        """
        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_classify",
            tokenization_kwargs=tokenization_kwargs,
        )

    def _embedding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Score text pairs by the cosine similarity of their embeddings.

        All texts from `data_1` and `data_2` are embedded in a single `encode`
        call (pooling task `"embed"`), and the outputs are split back into the
        two sides. A single `data_1` entry is paired with every `data_2` entry.

        Raises:
            NotImplementedError: If any input is not a plain string.
        """
        tokenizer = self.get_tokenizer()

        combined = [*data_1, *data_2]
        texts = [item for item in combined if isinstance(item, str)]
        if len(texts) != len(combined):
            # At least one entry is multimodal data rather than text.
            raise NotImplementedError(
                "Embedding scores currently do not support multimodal input."
            )

        embeddings = self.encode(
            texts,
            pooling_task="embed",
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            use_tqdm=use_tqdm,
        )

        num_left = len(data_1)
        left, right = embeddings[:num_left], embeddings[num_left:]
        if len(left) == 1:
            # Broadcast a single left-hand embedding across all right-hand ones.
            left = left * len(right)

        similarities = _cosine_similarity(
            tokenizer=tokenizer,
            embed_1=left,
            embed_2=right,
        )
        return [ScoringRequestOutput.from_base(item) for item in similarities]
1423

1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
    def _late_interaction_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Late interaction scoring (ColBERT MaxSim).

        Queries and documents (text or multimodal) are converted to prompts
        and encoded together with the `"token_embed"` pooling task. Each
        query/document pair is then scored with `compute_maxsim_score` on
        their token embeddings. A single query is paired with every document.

        The `prompt_token_ids` of each resulting output are the query tokens,
        followed by the tokenizer's pad token (only when it has one), followed
        by the document tokens.
        """
        from vllm.outputs import PoolingOutput

        tokenizer = self.get_tokenizer()
        model_config = self.model_config

        # Convert ScoreData to PromptType (handles both text and multimodal)
        query_prompts = score_data_to_prompts(data_1, "query", model_config)
        doc_prompts = score_data_to_prompts(data_2, "document", model_config)
        num_queries = len(query_prompts)

        token_embeds: list[PoolingRequestOutput] = self.encode(
            query_prompts + doc_prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_embed",
            tokenization_kwargs=tokenization_kwargs,
        )
        query_outs = token_embeds[:num_queries]
        doc_outs = token_embeds[num_queries:]
        if len(query_outs) == 1:
            query_outs = query_outs * len(doc_outs)

        pad_token_id = tokenizer.pad_token_id
        separator: list[int] = [] if pad_token_id is None else [pad_token_id]

        results: list[ScoringRequestOutput] = []
        for query_out, doc_out in zip(query_outs, doc_outs):
            # query_out.outputs.data: [query_len, dim]
            # doc_out.outputs.data: [doc_len, dim]
            maxsim = compute_maxsim_score(
                query_out.outputs.data, doc_out.outputs.data
            )
            pair_output = PoolingRequestOutput(
                request_id=f"{query_out.request_id}_{doc_out.request_id}",
                outputs=PoolingOutput(data=maxsim),
                prompt_token_ids=(
                    query_out.prompt_token_ids
                    + separator
                    + doc_out.prompt_token_ids
                ),
                num_cached_tokens=(
                    query_out.num_cached_tokens + doc_out.num_cached_tokens
                ),
                finished=True,
            )
            results.append(ScoringRequestOutput.from_base(pair_output))

        return results
1491

1492
1493
    def _cross_encoding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
        score_template: str | None,
    ) -> list[ScoringRequestOutput]:
        """
        Score each `(data_1[i], data_2[i])` pair with a cross-encoder.

        Each pair is turned into a single engine prompt by `get_score_prompt`
        and run with the `"classify"` pooling task. A single `data_1` entry is
        paired with every `data_2` entry.

        Note: when a `pooling_params` with `task=None` is passed in, its
        `task` is set to `"classify"` in place.

        Raises:
            ValueError: If the tokenizer is a Mistral tokenizer.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        if is_mistral_tokenizer(tokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="classify")
        elif pooling_params.task is None:
            pooling_params.task = "classify"

        prompts: list[PromptType] = []
        request_params: list[PoolingParams] = []
        for query, document in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=query,
                data_2=document,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            # Token type IDs travel (compressed) in the pooling params instead
            # of the prompt, so a pair that has them gets its own params copy.
            token_type_ids = engine_prompt.pop("token_type_ids", None)
            if token_type_ids:
                pair_params = pooling_params.clone()
                pair_params.extra_kwargs = {
                    "compressed_token_type_ids": compress_token_type_ids(
                        token_type_ids
                    )
                }
            else:
                pair_params = pooling_params

            request_params.append(pair_params)
            prompts.append(engine_prompt)

        outputs = self._run_completion(
            prompts=prompts,
            params=request_params,
            output_type=PoolingRequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
        return [ScoringRequestOutput.from_base(item) for item in outputs]
1552

1553
1554
    def score(
        self,
        data_1: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        data_2: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
        `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.

        Depending on the model's score type, pairs are scored by a
        cross-encoder, by late interaction (MaxSim over token embeddings),
        or by the cosine similarity of their embeddings. This method
        automatically batches the prompts, considering the memory
        constraint. For the best performance, put all of your inputs into
        a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: The chat template to use for the scoring. If None,
                we use the model's default chat template. Only supported for
                cross-encoder models.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, does not support
                scoring, is a cross-encoder with `num_labels != 1`, or if
                `chat_template` is given for a non-cross-encoder model.
        """
        model_config = self.model_config

        if model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks
        score_type = self.model_config.score_type
        is_late_interaction = score_type == "late-interaction"
        is_cross_encoder = score_type == "cross-encoder"

        # Late interaction models (e.g., ColBERT) score through token_embed,
        # so they are exempt from the embed/classify requirement.
        if not is_late_interaction and not (
            "embed" in supported_tasks or "classify" in supported_tasks
        ):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        if is_cross_encoder and getattr(model_config.hf_config, "num_labels", 0) != 1:
            raise ValueError("Score API is only enabled for num_labels == 1.")

        if chat_template is not None and not is_cross_encoder:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        score_data_1, score_data_2 = validate_score_input(
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            is_multimodal_model=model_config.is_multimodal_model,
            architecture=model_config.architecture,
        )

        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )
        encode_kwargs = tok_params.get_encode_kwargs()

        if is_cross_encoder:
            return self._cross_encoding_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
                score_template=chat_template,
            )

        if is_late_interaction:
            return self._late_interaction_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
            )

        return self._embedding_score(
            score_data_1,
            score_data_2,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
        )
1686

1687
1688
1689
1690
1691
1692
1693
1694
1695
    def start_profile(self, profile_prefix: str | None = None) -> None:
        """Start profiling with optional custom trace prefix.

        Args:
            profile_prefix: Optional prefix for the trace file names. If provided,
                           trace files will be named as "<prefix>_dp<X>_pp<Y>_tp<Z>".
                           If not provided, default naming will be used.
        """
        self.llm_engine.start_profile(profile_prefix)
1696
1697
1698
1699

    def stop_profile(self) -> None:
        """Stop the profiling session on the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

1700
1701
1702
1703
1704
1705
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache.

        Args:
            reset_running_requests: Forwarded as the first positional argument
                of the engine's `reset_prefix_cache`.
            reset_connector: Forwarded as the second positional argument.

        Returns:
            The engine's return value (annotated as `bool`).
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(reset_running_requests, reset_connector)
1706

1707
    def sleep(self, level: int = 1, mode: PauseMode = "abort"):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Both arguments are forwarded unchanged to the engine's `sleep`.

        Args:
            level: The sleep level.
                - Level 0: Pause scheduling but continue accepting requests.
                           Requests are queued but not processed.
                - Level 1: Offload model weights to CPU, discard KV cache.
                           The content of kv cache is forgotten. Good for
                           sleeping and waking up the engine to run the same
                           model again. Please make sure there's enough CPU
                           memory to store the model weights.
                - Level 2: Discard all GPU memory (weights + KV cache).
                           Good for sleeping and waking up the engine to run
                           a different model or update the model, where
                           previous model weights are not needed. It reduces
                           CPU memory pressure.
            mode: How to handle any existing requests, can be "abort", "wait",
                or "keep".
        """
        engine = self.llm_engine
        engine.sleep(level=level, mode=mode)
1731

1732
    def wake_up(self, tags: list[str] | None = None):
1733
        """
1734
1735
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1736

1737
        Args:
1738
1739
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1740
1741
1742
1743
                `("weights", "kv_cache", "scheduling")`. If None, all memory
                is reallocated. wake_up should be called with all tags
                (or None) before the engine is used again.
                Use tags=["scheduling"] to resume from level 0 sleep.
1744
1745
        """
        self.llm_engine.wake_up(tags)
1746

1747
1748
1749
1750
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics, as returned by the engine's `get_metrics`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        engine = self.llm_engine
        return engine.get_metrics()

1759
    def _params_to_seq(
        self,
        params: _P | Sequence[_P],
        num_requests: int,
    ) -> Sequence[_P]:
        """Expand `params` into one entry per request.

        A sequence is returned unchanged after its length is checked; a
        single params object is repeated `num_requests` times.

        Raises:
            ValueError: If `params` is a sequence whose length differs from
                `num_requests`.
        """
        if isinstance(params, Sequence):
            if len(params) != num_requests:
                # Report the prompt *count*; interpolating `params` here would
                # dump the whole params sequence into the message.
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and params ({len(params)}) must be the same."
                )
            return params

        return [params] * num_requests

    def _lora_request_to_seq(
        self,
        lora_request: LoRARequest | None | Sequence[LoRARequest | None],
        num_requests: int,
    ) -> Sequence[LoRARequest | None]:
        """Expand `lora_request` into one entry per request.

        A single request (or None) is repeated `num_requests` times; a
        sequence is returned unchanged once its length has been checked.

        Raises:
            ValueError: If a sequence is given whose length differs from
                `num_requests`.
        """
        if not isinstance(lora_request, Sequence):
            return [lora_request] * num_requests

        if len(lora_request) != num_requests:
            raise ValueError(
                f"The lengths of prompts ({num_requests}) "
                f"and lora_request ({len(lora_request)}) must be the same."
            )
        return lora_request
1790

1791
1792
1793
1794
1795
    def _priority_to_seq(
        self,
        priority: list[int] | None,
        num_requests: int,
    ) -> Sequence[int]:
1796
1797
1798
1799
1800
1801
1802
        if priority is not None:
            if len(priority) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and priority ({len(priority)}) must be the same."
                )

1803
1804
1805
1806
            return priority

        return [0] * num_requests

1807
    def _add_completion_requests(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Render completion prompts and add them to the engine.

        `prompts` is normalized with `prompt_to_seq`; `params`,
        `lora_request` and `priority` are then expanded/validated against the
        number of normalized prompts. Rendering is done lazily through a
        generator, one prompt at a time, as requests are added.

        Returns:
            The request IDs of the added requests.

        Raises:
            ValueError: If a per-request sequence argument does not match the
                number of prompts.
        """
        seq_prompts = prompt_to_seq(prompts)
        num_requests = len(seq_prompts)

        seq_params = self._params_to_seq(params, num_requests)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
        # Must use the normalized count: `prompts` may be a single prompt, for
        # which `len(prompts)` would count characters or dict keys rather than
        # requests (wrongly rejecting a valid `priority`, or yielding too few
        # priority entries).
        seq_priority = self._priority_to_seq(priority, num_requests)

        return self._render_and_add_requests(
            prompts=(
                self._preprocess_cmpl_one(prompt, tokenization_kwargs)
                for prompt in maybe_tqdm(
                    seq_prompts,
                    use_tqdm=use_tqdm,
                    desc="Rendering prompts",
                )
            ),
            params=seq_params,
            lora_requests=seq_lora_requests,
            priorities=seq_priority,
        )

1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
    def _run_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ):
        """Queue completion requests for `prompts`, then drive the engine.

        Returns:
            The finished outputs collected by `_run_engine`, each checked to
            be an instance of `output_type`.
        """
        self._add_completion_requests(
            prompts,
            params,
            priority=priority,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            use_tqdm=use_tqdm,
        )
        return self._run_engine(output_type, use_tqdm=use_tqdm)

1861
1862
1863
1864
1865
1866
1867
    def _run_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ):
        """Render chat conversations and run them through the engine.

        `messages` is normalized with `conversation_to_seq`; `params` and
        `lora_request` are expanded/validated per conversation. Each
        conversation is rendered with `_preprocess_chat_one` using the given
        chat-template options.

        Returns:
            The finished outputs from `_render_and_run_requests`.
        """
        conversations = conversation_to_seq(messages)
        num_convs = len(conversations)
        per_conv_params = self._params_to_seq(params, num_convs)
        per_conv_loras = self._lora_request_to_seq(lora_request, num_convs)

        # Keep this a generator so each conversation is rendered right before
        # it is added to the engine instead of rendering everything up front.
        rendered_prompts = (
            self._preprocess_chat_one(
                conv,
                chat_template=chat_template,
                chat_template_content_format=chat_template_content_format,
                chat_template_kwargs=chat_template_kwargs,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                tokenization_kwargs=tokenization_kwargs,
                mm_processor_kwargs=mm_processor_kwargs,
            )
            for conv in maybe_tqdm(
                conversations,
                use_tqdm=use_tqdm,
                desc="Rendering conversations",
            )
        )

        return self._render_and_run_requests(
            prompts=rendered_prompts,
            params=per_conv_params,
            output_type=output_type,
            lora_requests=per_conv_loras,
            use_tqdm=use_tqdm,
        )

1910
1911
1912
1913
    def _render_and_run_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ):
        """Add the rendered prompts to the engine and run it to completion.

        Logs a one-time warning when `prompts` is an already-materialized
        list or tuple, since a lazy generator lets the engine start on the
        first prompt while later ones are still being rendered.

        Returns:
            The finished outputs from `_run_engine`.
        """
        prerendered = isinstance(prompts, (list, tuple))
        if prerendered:
            logger.warning_once(
                "Rendering all prompts before adding them to the engine "
                "is less efficient than performing both on the same prompt "
                "before processing the next prompt. You should instead pass "
                "a generator that renders one prompt per iteration, as that allows "
                "engine execution to begin for the first prompt while processing "
                "the next prompt."
            )

        self._render_and_add_requests(
            prompts,
            params,
            lora_requests=lora_requests,
            priorities=priorities,
        )
        return self._run_engine(output_type, use_tqdm=use_tqdm)
1938

1939
    def _render_and_add_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
    ) -> list[str]:
        """Add each rendered prompt to the engine as its own request.

        `prompts` is consumed lazily, so if it is a generator each prompt is
        rendered immediately before it is added. `params`, and
        `lora_requests`/`priorities` when given, are indexed in step with
        `prompts`; missing LoRA requests default to None (then passed through
        `_resolve_mm_lora`) and missing priorities default to 0.

        If rendering or adding any prompt raises, the requests already added
        by this call are aborted before the exception propagates.

        Returns:
            The request IDs of the added requests, in input order.
        """
        added_request_ids: list[str] = []

        try:
            for i, prompt in enumerate(prompts):
                request_id = self._add_request(
                    prompt,
                    params[i],
                    lora_request=self._resolve_mm_lora(
                        prompt,
                        None if lora_requests is None else lora_requests[i],
                    ),
                    priority=0 if priorities is None else priorities[i],
                )
                added_request_ids.append(request_id)
        except Exception:
            # Don't leave a partially-added batch behind in the engine.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            # Bare `raise` re-raises the active exception with its original
            # traceback, without an extra re-raise frame.
            raise

        return added_request_ids

1968
    def _add_request(
        self,
        prompt: ProcessorInputs,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
    ) -> str:
        """Submit one rendered prompt to the engine under a fresh request ID.

        `SamplingParams` are modified in place: their `output_kind` is forced
        to `RequestOutputKind.FINAL_ONLY`.

        Returns:
            The value returned by the engine's `add_request`.
        """
        if isinstance(params, SamplingParams):
            # Offline runs only collect finished outputs, so intermediate
            # outputs are not needed.
            params.output_kind = RequestOutputKind.FINAL_ONLY

        counter_value = next(self.request_counter)
        request_id = str(counter_value)

        engine = self.llm_engine
        return engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
        )
1988

1989
    def _run_engine(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]:
        """Step the engine until it has no unfinished requests left.

        Args:
            output_type: Expected type (or tuple of types) of every engine
                output; checked with `assert isinstance`.
            use_tqdm: If `True`, shows a tqdm progress bar. If a callable, it
                is used to create the progress bar. If `False`, no progress
                bar is created.

        Returns:
            All finished outputs, sorted by their integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[_O] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                assert isinstance(output, output_type)
                if not output.finished:
                    continue
                outputs.append(output)  # type: ignore[arg-type]
                if not use_tqdm:
                    continue

                if isinstance(output, RequestOutput):
                    # Calculate tokens only for RequestOutput
                    n = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids) * n
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs
                    )
                    # `elapsed` can be 0 (e.g. a disabled tqdm bar reports 0,
                    # or a request finished within the timer's resolution);
                    # skip the speed estimate instead of dividing by zero.
                    elapsed = pbar.format_dict["elapsed"]
                    if elapsed:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s"
                        )
                    pbar.update(n)
                else:
                    pbar.update(1)

                if pbar.n == num_requests:
                    pbar.refresh()

        if use_tqdm:
            pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
2043

2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
    def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest | dict
    ) -> None:
        """
        Initialize weight transfer for RL training.

        The backend-specific `init_info` is taken from `request` (a dict key
        or an attribute) and sent to the workers via the
        `init_weight_transfer_engine` collective RPC.

        Args:
            request: Weight transfer initialization request with backend-specific info
        """
        if isinstance(request, dict):
            init_info = request["init_info"]
        else:
            init_info = request.init_info

        self.llm_engine.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": init_info}
        )

    def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
        """
        Update the weights of the model.

        The backend-specific `update_info` is taken from `request` (a dict key
        or an attribute) and sent to the workers via the `update_weights`
        collective RPC.

        Args:
            request: Weight update request with backend-specific update info
        """
        if isinstance(request, dict):
            update_info = request["update_info"]
        else:
            update_info = request.update_info

        self.llm_engine.collective_rpc(
            "update_weights", kwargs={"update_info": update_info}
        )

2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model."""
        # Cache the result to avoid repeated collective_rpc calls
        if self._cached_repr is None:
            results = self.llm_engine.collective_rpc("get_model_inspection")
            # In distributed settings, we get results from all workers
            # Just return the first one (they should all be the same)
            if results:
                self._cached_repr = results[0]
            else:
                self._cached_repr = f"LLM(model={self.model_config.model!r})"
        return self._cached_repr