llm.py 83.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Callable, Iterable, Sequence
6
from pathlib import Path
7
from typing import TYPE_CHECKING, Any
8

9
import cloudpickle
10
import torch.nn as nn
11
from pydantic import ValidationError
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar, overload
14

15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
22
    AttentionConfig,
23
    CompilationConfig,
24
    PoolerConfig,
25
    ProfilerConfig,
26
27
28
    StructuredOutputsConfig,
    is_init_field,
)
29
from vllm.config.compilation import CompilationMode
30
from vllm.config.model import (
31
32
    ConvertOption,
    HfOverrides,
33
    ModelDType,
34
    RunnerOption,
35
    TokenizerMode,
36
)
37
38
39
40
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
41
from vllm.engine.arg_utils import EngineArgs
42
43
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
44
    ChatTemplateConfig,
45
    ChatTemplateContentFormatOption,
46
    load_chat_template,
47
)
48
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
49
from vllm.entrypoints.pooling.score.utils import (
50
    ScoreData,
51
52
53
    ScoreMultiModalParam,
    _cosine_similarity,
    compress_token_type_ids,
54
    compute_maxsim_score,
55
    get_score_prompt,
56
    score_data_to_prompts,
57
    validate_score_input,
58
)
59
from vllm.entrypoints.utils import log_non_default_args
60
from vllm.inputs.data import (
61
    DataPrompt,
62
    ProcessorInputs,
63
64
65
66
67
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
68
from vllm.logger import init_logger
69
from vllm.lora.request import LoRARequest
70
from vllm.model_executor.layers.quantization import QuantizationMethods
71
72
73
74
75
76
77
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
78
from vllm.platforms import current_platform
79
from vllm.pooling_params import PoolingParams
80
from vllm.renderers import ChatParams, merge_kwargs
81
82
83
84
85
from vllm.renderers.inputs.preprocess import (
    conversation_to_seq,
    parse_model_prompt,
    prompt_to_seq,
)
86
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
87
from vllm.tasks import PoolingTask
88
from vllm.tokenizers import TokenizerLike
yhu422's avatar
yhu422 committed
89
from vllm.usage.usage_lib import UsageContext
90
from vllm.utils.counter import Counter
91
from vllm.utils.mistral import is_mistral_tokenizer
92
from vllm.utils.tqdm_utils import maybe_tqdm
93
from vllm.v1.engine import PauseMode
94
from vllm.v1.engine.llm_engine import LLMEngine
95
from vllm.v1.sample.logits_processor import LogitsProcessor
96

97
98
99
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

100
101
logger = init_logger(__name__)

# Output type(s) that `wait_for_completion` collects; when not specified, it
# covers both generation and pooling outputs.
_O = TypeVar(
    "_O",
    bound=RequestOutput | PoolingRequestOutput,
    default=RequestOutput | PoolingRequestOutput,
)
# Per-request parameter type: sampling params, pooling params, or None.
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
# Generic result type for helpers such as `collective_rpc` / `apply_model`
# (falls back to Any when it cannot be inferred).
_R = TypeVar("_R", default=Any)

110
111

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
112
113
114
115
116
117
118
119
120
121
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
122
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
123
124
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
125
126
127
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
128
129
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
130
131
132
133
        allowed_local_media_path: Allows API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
136
137
138
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
139
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
140
141
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
142
        quantization: The method used to quantize the model weights. Currently,
143
            we support "awq", "gptq", and "fp8" (experimental).
144
145
146
147
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
148
149
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
150
151
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
152
        chat_template: The chat template to apply.
153
154
155
156
157
158
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
159
160
161
162
163
        kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. By
            default, this is None and vLLM automatically infers the KV cache
            size from `gpu_memory_utilization`. Setting it explicitly gives
            finer-grained control over how much memory is used than
            `gpu_memory_utilization` does. Note that when
            `kv_cache_memory_bytes` is not None, `gpu_memory_utilization` is
            ignored.
167
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, values that are too small may cause out-of-memory (OOM)
            errors. Note that `best_of` is only supported in V0.
173
174
175
176
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
177
178
179
180
181
182
183
184
185
186
187
188
189
        offload_group_size: Prefetch offloading: Group every N layers
            together. Offload last `offload_num_in_group` layers of each group.
            Default is 0 (disabled).
        offload_num_in_group: Prefetch offloading: Number of layers to
            offload per group. Default is 1.
        offload_prefetch_step: Prefetch offloading: Number of layers to
            prefetch ahead. Higher values hide more latency but use more GPU
            memory. Default is 1.
        offload_params: Prefetch offloading: Set of parameter name segments
            to selectively offload. Only parameters whose names contain one of
            these segments will be offloaded (e.g., {"gate_up_proj", "down_proj"}
            for MLP weights, or {"w13_weight", "w2_weight"} for MoE expert
            weights). If None or empty, all parameters are offloaded.
190
191
192
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
193
        enable_return_routed_experts: Whether to return routed experts.
194
195
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
196
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, the token generated when running `hf auth login`
            (stored in `~/.cache/huggingface/token`) is used.
199
200
201
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
202
203
204
205
206
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
207
208
        pooler_config: Initialize non-default pooling config for the pooling model,
            e.g., `PoolerConfig(seq_pooling_type="MEAN", use_activation=False)`.
209
        compilation_config: An integer, a dictionary, or a `CompilationConfig`
            instance. If it is an integer, it is used as the mode of
            compilation optimization. If it is a dictionary, it can specify
            the full compilation configuration.
212
213
214
215
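        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
        structured_outputs_config: Configuration for structured outputs. Can
            be a dictionary or a StructuredOutputsConfig instance; a dictionary
            is converted to a StructuredOutputsConfig.
        profiler_config: Profiler configuration. Can be a dictionary or a
            ProfilerConfig instance; a dictionary is converted to a
            ProfilerConfig.
        logits_processors: Custom logits processors, given as strings or as
            `LogitsProcessor` types.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].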
nunjunj's avatar
nunjunj committed
217

218
219
    Note:
        This class is intended to be used for offline inference. For online
220
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
221
    """
222
223
224
225

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        chat_template: Path | str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        offload_group_size: int = 0,
        offload_num_in_group: int = 1,
        offload_prefetch_step: int = 1,
        offload_params: set[str] | None = None,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes the user-facing arguments into config objects, builds an
        `EngineArgs`, creates the `LLMEngine`, and caches the engine-side
        helpers (renderer, processors, model config) used by this class.
        Parameters are documented on the class; extra keyword arguments are
        forwarded to `EngineArgs`.

        Raises:
            ValueError: If `kv_transfer_config` is given as a dict that fails
                validation, or if `data_parallel_size > 1` is requested on a
                non-TPU platform without the `external_launcher` executor
                backend.
        """
        # Stats logging is disabled by default for offline use; an explicit
        # `disable_log_stats` from the caller is respected.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if "kv_transfer_config" in kwargs and isinstance(
            kwargs["kv_transfer_config"], dict
        ):
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                # Surface a ValueError (with the original error chained) so
                # callers get a standard exception type.
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance.

            None yields a default-constructed `cls`; a dict is passed as
            keyword arguments, keeping only keys accepted by
            `is_init_field(cls, key)`; anything else is returned unchanged.
            """
            if value is None:
                return cls()
            if isinstance(value, dict):
                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
            return value

        # An integer compilation_config is shorthand for the compilation mode.
        if isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data parallel usage (it may hang), unless the
        # external launcher backend is used or the platform is TPU.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            offload_group_size=offload_group_size,
            offload_num_in_group=offload_num_in_group,
            offload_prefetch_step=offload_prefetch_step,
            offload_params=offload_params or set(),
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled in lazily by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        self.model_config = self.llm_engine.model_config
        self.renderer = self.llm_engine.renderer
        self.chat_template = load_chat_template(chat_template)
        self.io_processor = self.llm_engine.io_processor
        self.input_processor = self.llm_engine.input_processor
        self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
        self.init_pooling_io_processors = init_pooling_io_processors(
            supported_tasks=supported_tasks,
            model_config=self.model_config,
            renderer=self.renderer,
            chat_template_config=self.chat_template_config,
        )

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

406
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer provided by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
408

409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
    def get_world_size(self, include_dp: bool = True) -> int:
        """Return the number of workers described by the parallel config.

        Args:
            include_dp: When True (the default), data-parallel replicas are
                counted too (TP * PP * DP). When False, only a single
                data-parallel replica is counted (TP * PP).

        Returns:
            `world_size_across_dp` when `include_dp` is True, otherwise
            `world_size`, both read from the engine's parallel config.
        """
        cfg = self.llm_engine.vllm_config.parallel_config
        return cfg.world_size_across_dp if include_dp else cfg.world_size

426
    def reset_mm_cache(self) -> None:
        """Clear the multi-modal caches held by the renderer and the engine."""
        renderer, engine = self.renderer, self.llm_engine
        # Front-end (renderer) cache first, then the engine-side cache.
        renderer.clear_mm_cache()
        engine.reset_mm_cache()

430
    def get_default_sampling_params(self) -> SamplingParams:
        """Build the default `SamplingParams` for this model.

        The model config's sampling-parameter overrides are fetched once and
        memoized on `self.default_sampling_params`. If any overrides exist,
        they are applied via `SamplingParams.from_optional`; otherwise a plain
        `SamplingParams()` is returned.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = self.model_config.get_diff_sampling_param()

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

437
438
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """Generate completions for the given prompts.

        This method batches the prompts automatically, taking memory
        constraints into account. For the best throughput, pass all prompts
        in a single call as one list.

        Args:
            prompts: The prompt(s) to complete. A sequence of prompts may be
                passed for batch inference; see
                [PromptType][vllm.inputs.PromptType] for the accepted formats.
            sampling_params: Sampling parameters for text generation. If
                None, the model's default sampling parameters are used. A
                single value applies to every prompt; a list must match the
                length of `prompts` and is paired with them element-wise.
            use_tqdm: `True` shows a tqdm progress bar; a callable (e.g.
                `functools.partial(tqdm, leave=False)`) is used to create the
                progress bar; `False` disables it.
            lora_request: LoRA request(s) to use for generation, if any.
            priority: Per-request priorities, if any. Only applicable when
                the priority scheduling policy is enabled. If provided, it
                must be a list of integers with the same length as `prompts`,
                where each value applies to the prompt at the same index.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `RequestOutput` objects with the generated completions,
            in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not "generate".
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        params = (
            self.get_default_sampling_params()
            if sampling_params is None
            else sampling_params
        )
        return self._run_completion(
            prompts=prompts,
            params=params,
            output_type=RequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
498

499
500
501
502
    def enqueue(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Add prompts to the engine queue without waiting for results.

        The requests are only queued here; call `wait_for_completion()` to
        process them and collect their outputs.

        Args:
            prompts: The prompts to the LLM. See `generate()` for details.
            sampling_params: Sampling parameters for text generation. If
                None, the model's default sampling parameters are used.
            lora_request: LoRA request(s) to use for generation, if any.
            priority: Per-request priorities, if any.
            use_tqdm: If True, shows a tqdm progress bar while requests are
                being added.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            The request IDs of the enqueued requests.

        Raises:
            ValueError: If the model's runner type is not "generate".
        """
        if self.model_config.runner_type != "generate":
            raise ValueError("LLM.enqueue() is only supported for generative models.")

        params = (
            sampling_params
            if sampling_params is not None
            else self.get_default_sampling_params()
        )
        return self._add_completion_requests(
            prompts=prompts,
            params=params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
        )

541
    @overload
    def wait_for_completion(
        self,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[RequestOutput | PoolingRequestOutput]: ...

    @overload
    def wait_for_completion(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]: ...

    def wait_for_completion(
        self,
        output_type: type[Any] | tuple[type[Any], ...] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[Any]:
        """Wait for all enqueued requests to complete and return results.

        This method processes all requests currently in the engine queue
        and returns their outputs. Use after `enqueue()` to get results.

        Args:
            output_type: The expected output type, or a tuple of types. If
                None (the default), both `RequestOutput` and
                `PoolingRequestOutput` are accepted.
            use_tqdm: If True, shows a tqdm progress bar. A callable is used
                to create the progress bar (see `generate()`); False
                disables it.

        Returns:
            A list of output objects for all completed requests.
        """
        if output_type is None:
            # Accept both generation and pooling outputs by default.
            output_type = (RequestOutput, PoolingRequestOutput)

        return self._run_engine(output_type, use_tqdm=use_tqdm)
578

Cyrus Leung's avatar
Cyrus Leung committed
579
    def _resolve_mm_lora(
        self,
        prompt: ProcessorInputs,
        lora_request: LoRARequest | None,
    ) -> LoRARequest | None:
        """Choose the LoRA request to use for a processed prompt.

        The given `lora_request` is returned unchanged when the prompt is not
        multimodal, when no default multi-modal LoRAs are configured, or when
        none of the prompt's modalities has a registered default LoRA.

        If more than one of the prompt's modalities has a default LoRA, a
        warning is logged and `lora_request` is returned, because only one
        LoRA is applied per request.

        With exactly one matching modality, its default LoRA ID is the
        1-based position of that modality in the sorted list of registered
        modality names. An explicitly provided `lora_request` always takes
        precedence; a warning is logged if its ID differs from the default.
        Otherwise a new `LoRARequest` for the modality's default LoRA is
        returned.
        """
        if prompt["type"] != "multimodal":
            return lora_request

        lora_config = self.llm_engine.vllm_config.lora_config
        mm_lora_paths = (
            lora_config.default_mm_loras if lora_config is not None else None
        )
        if not mm_lora_paths:
            return lora_request

        matched = set(prompt["mm_placeholders"].keys()).intersection(
            mm_lora_paths.keys()
        )
        if not matched:
            return lora_request

        if len(matched) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be "
                "used by a single prompt consuming several modalities; "
                "currently we only support one lora per request; as such, "
                "lora(s) registered with modalities: %s will be skipped",
                matched,
            )
            return lora_request

        (modality,) = matched
        # Default mm LoRA IDs are 1-based positions in the sorted modality names.
        default_lora_id = sorted(mm_lora_paths).index(modality) + 1

        if not lora_request:
            return LoRARequest(
                modality,
                default_lora_id,
                mm_lora_paths[modality],
            )

        # An explicitly provided request always wins; only warn on ID mismatch.
        if lora_request.lora_int_id != default_lora_id:
            logger.warning(
                "A modality with a registered lora and a lora_request "
                "with a different ID were provided; falling back to the "
                "lora_request as we only apply one LoRARequest per prompt"
            )
        return lora_request

631
632
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Invoke a method on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and executed on each worker.

                A callable receives the worker object as an extra leading
                `self` argument, followed by `args` and `kwargs`.
            timeout: Maximum number of seconds to wait. A
                [`TimeoutError`][] is raised on timeout; `None` waits
                indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, collected into a list.

        Note:
            Prefer using this API for control messages only, and set up
            dedicated data-plane communication for transferring data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
662
663

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model held by each worker and collect the results,
        one per worker.

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
675

676
677
    def beam_search(
        self,
        prompts: list[TokensPrompt | TextPrompt],
        params: BeamSearchParams,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        use_tqdm: bool = False,
        concurrency_limit: int | None = None,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            use_tqdm: Whether to use tqdm to display the progress bar.
                The progress bar is disabled when `concurrency_limit` is set.
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
                When given, it must be a positive integer.

        Returns:
            One `BeamSearchOutput` per prompt, in input order, each holding
            at most `beam_width` sequences sorted by the beam score.

        Raises:
            ValueError: If `concurrency_limit` is given but is not positive.
            NotImplementedError: If a prompt is an embedding prompt or an
                encoder-decoder prompt.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        if concurrency_limit is not None and concurrency_limit <= 0:
            # A zero step would crash `range` below, and a negative one would
            # silently skip generation for every prompt.
            raise ValueError(
                f"concurrency_limit must be a positive integer, got {concurrency_limit}"
            )

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        engine_prompts = self._preprocess_cmpl(prompts)
        lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts))

        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar."
            )
            use_tqdm = False

        if concurrency_limit is None:
            # Run every prompt in one batch. Clamp to 1 so that an empty
            # prompt list does not produce a zero `range` step below.
            concurrency_limit = max(1, len(engine_prompts))

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        sampling_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
        )
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, engine_prompts):
            if prompt["type"] == "embeds":
                raise NotImplementedError(
                    "Embedding prompt not supported for beam search"
                )
            if prompt["type"] == "enc_dec":
                raise NotImplementedError(
                    "Encoder-decoder prompt not supported for beam search"
                )

            instances.append(
                BeamSearchInstance(
                    prompt,
                    lora_request=lora_req,
                    logprobs=None,
                ),
            )

        for prompt_start in range(0, len(instances), concurrency_limit):
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress."
                )

            for _ in token_iter:
                # Flatten the live beams of every instance in the batch into a
                # single list. This is linear time, unlike `sum(lists, [])`.
                all_beams: list[BeamSearchSequence] = list(
                    itertools.chain.from_iterable(
                        instance.beams for instance in instances_batch
                    )
                )
                # pos[k] is the offset in `all_beams` where the beams of the
                # k-th instance begin; pos[-1] == len(all_beams).
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch
                    )
                )
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:])
                )

                if len(all_beams) == 0:
                    break

                # only runs for one step
                # we don't need to use tqdm here
                output = self._render_and_run_requests(
                    prompts=(beam.get_prompt() for beam in all_beams),
                    params=self._params_to_seq(sampling_params, len(all_beams)),
                    output_type=RequestOutput,
                    lora_requests=[beam.lora_request for beam in all_beams],
                    use_tqdm=False,
                )

                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    current_beam.orig_prompt,
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs + [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                )

                                if token_id == eos_token_id and not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
                    instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive when the token budget ran out are candidates
            # alongside the ones that already hit EOS.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)

            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

837
    def _preprocess_cmpl(
        self,
        prompts: Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Turn prompts given to the non-chat LLM APIs (everything other than
        [LLM.chat][]) into inputs that `_add_request` can accept.

        See [LLM.generate][] for the full description of each argument.

        Returns:
            A list of `ProcessorInputs` objects ready to be passed into
            LLMEngine.
        """
        renderer, model_config = self.renderer, self.model_config

        parsed_prompts = [parse_model_prompt(model_config, p) for p in prompts]

        # Caller-supplied tokenizer overrides layered on the renderer defaults.
        overrides = tokenization_kwargs or {}
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(**overrides)

        return renderer.render_cmpl(parsed_prompts, tok_params)
862

863
864
865
866
867
868
869
870
    def _preprocess_cmpl_one(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """Preprocess a single completion prompt through `_preprocess_cmpl`.

        Raises:
            ValueError: If preprocessing does not yield exactly one input.
        """
        [engine_prompt] = self._preprocess_cmpl([prompt], tokenization_kwargs)
        return engine_prompt

871
872
    def _preprocess_chat(
        self,
        conversations: Sequence[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Render chat conversations into prompts that the other LLM APIs can
        consume.

        See [LLM.chat][] for the full description of each argument.

        Returns:
            A list of `ProcessorInputs` objects ready to be passed into
            LLMEngine.
        """
        renderer = self.renderer

        # Template options derived from this call's arguments; combined with
        # the caller's `chat_template_kwargs` through `merge_kwargs`.
        template_options = dict(
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenize=is_mistral_tokenizer(renderer.tokenizer),
        )
        merged_template_kwargs = merge_kwargs(chat_template_kwargs, template_options)
        chat_params = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=merged_template_kwargs,
        )

        overrides = tokenization_kwargs or {}
        tok_params = renderer.default_chat_tok_params.with_kwargs(**overrides)

        # render_chat returns a pair; only the engine prompts are needed here.
        _, engine_prompts = renderer.render_chat(
            conversations,
            chat_params,
            tok_params,
            prompt_extras={"mm_processor_kwargs": mm_processor_kwargs},
        )
        return engine_prompts
919

920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
    def _preprocess_chat_one(
        self,
        conversation: list[ChatCompletionMessageParam],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """Preprocess a single conversation through `_preprocess_chat`.

        Raises:
            ValueError: If preprocessing does not yield exactly one input.
        """
        engine_prompts = self._preprocess_chat(
            [conversation],
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        [engine_prompt] = engine_prompts
        return engine_prompt

946
947
    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the [generate][vllm.LLM.generate] method to generate
        the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A sequence of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "auto" (the default) lets vLLM choose between the formats
                  listed below.
                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, asks the chat template to add its
                generation prompt after the last message of each conversation
                (not after every message).
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions (e.g. for function calling) that
                are made available to the chat template.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            mm_processor_kwargs: Overrides for `processor.__call__`.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.

        Raises:
            ValueError: If the model is not running with the "generate"
                runner.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.chat() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        return self._run_chat(
            messages=messages,
            params=sampling_params,
            output_type=RequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

1040
1041
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
                A dict with a `"data"` key is instead routed through the
                installed IOProcessor plugin.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For a `"data"` prompt handled by the IOProcessor plugin, the list
            holds a single output wrapping the plugin's post-processed result.

        Raises:
            ValueError: If `pooling_task` is missing, the model is not run
                with the "pooling" runner, a `"data"` prompt is given without
                an IOProcessor plugin or with `None` data, or a
                `PoolingParams.task` conflicts with `pooling_task`.
        """
        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if isinstance(prompts, dict) and "data" in prompts:
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            prompt_data = prompts.get("data")
            if prompt_data is None:
                raise ValueError(
                    "The 'data' field of the prompt is expected to contain "
                    "the prompt data and it cannot be None. "
                    "Refer to the documentation of the IOProcessor "
                    "in use for more details."
                )
            validated_prompt = self.io_processor.parse_data(prompt_data)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)
            prompts_seq = prompt_to_seq(prompts)

            params_seq: Sequence[PoolingParams] = [
                self.io_processor.merge_pooling_params(param)
                for param in self._params_to_seq(
                    pooling_params,
                    len(prompts_seq),
                )
            ]
            for p in params_seq:
                if p.task is None:
                    p.task = "plugin"

            outputs = self._run_completion(
                prompts=prompts_seq,
                params=params_seq,
                output_type=PoolingRequestOutput,
                use_tqdm=use_tqdm,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )

            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(outputs)

            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
                    prompt_token_ids=[],
                    finished=True,
                )
            ]
        else:
            if pooling_params is None:
                # Use default pooling params.
                pooling_params = PoolingParams()

            prompts_seq = prompt_to_seq(prompts)
            params_seq = self._params_to_seq(pooling_params, len(prompts_seq))

            for param in params_seq:
                if param.task is None:
                    param.task = pooling_task
                elif param.task != pooling_task:
                    msg = (
                        f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                    )
                    raise ValueError(msg)

            if pooling_task in self.init_pooling_io_processors:
                io_processor = self.init_pooling_io_processors[pooling_task]
                processor_inputs = io_processor.pre_process_offline(
                    prompts_seq, tokenization_kwargs
                )
                seq_lora_requests = self._lora_request_to_seq(
                    lora_request, len(prompts_seq)
                )
                # Size by the normalized sequence, not the raw `prompts`
                # argument: for a single str/dict prompt, `len(prompts)`
                # counts characters/keys instead of prompts.
                seq_priority = self._priority_to_seq(None, len(prompts_seq))

                self._render_and_add_requests(
                    prompts=processor_inputs,
                    params=params_seq,
                    lora_requests=seq_lora_requests,
                    priorities=seq_priority,
                )

                outputs = self._run_engine(
                    use_tqdm=use_tqdm, output_type=PoolingRequestOutput
                )
                outputs = io_processor.post_process(outputs)
            else:
                outputs = self._run_completion(
                    prompts=prompts_seq,
                    params=params_seq,
                    output_type=PoolingRequestOutput,
                    use_tqdm=use_tqdm,
                    lora_request=lora_request,
                    tokenization_kwargs=tokenization_kwargs,
                )
        return outputs
1209

1210
1211
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Produce one embedding vector per prompt.

        Prompts are batched automatically with the memory constraint in mind;
        for best throughput, pass every prompt in a single list.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            `EmbeddingRequestOutput` objects holding the embedding vectors,
            ordered like the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        pooled = self.encode(
            prompts,
            pooling_task="embed",
            pooling_params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results (class logits) in the same order as the
            input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1310
1311
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Compute rewards for each prompt.

        This is `encode` with the pooling task fixed to "token_classify".

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            `PoolingRequestOutput` objects holding the pooled hidden states,
            ordered like the input prompts.
        """
        pooled = self.encode(
            prompts,
            pooling_task="token_classify",
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
            tokenization_kwargs=tokenization_kwargs,
        )
        return pooled

1349
1350
    def _embedding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """Score pairs by cosine similarity of their embeddings.

        Both sides are embedded with a single `encode` call (task `"embed"`).
        The first `len(data_1)` outputs belong to `data_1`, and the remaining
        outputs belong to `data_2`. A lone `data_1` output is repeated so it
        pairs with every `data_2` output.

        Raises:
            NotImplementedError: If any entry of `data_1` or `data_2` is not a
                plain string (multimodal input is unsupported here).
        """
        tokenizer = self.get_tokenizer()

        # Keep only string entries; any dropped entry means multimodal input.
        input_texts = [item for item in data_1 + data_2 if isinstance(item, str)]
        if len(input_texts) != len(data_1) + len(data_2):
            raise NotImplementedError(
                "Embedding scores currently do not support multimodal input."
            )

        embeddings = self.encode(
            input_texts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        split_at = len(data_1)
        query_embeds = embeddings[:split_at]
        doc_embeds = embeddings[split_at:]

        if len(query_embeds) == 1:
            # 1 -> N scoring: reuse the single query embedding for each doc.
            query_embeds = query_embeds * len(doc_embeds)

        similarities = _cosine_similarity(
            tokenizer=tokenizer,
            embed_1=query_embeds,
            embed_2=doc_embeds,
        )
        return [ScoringRequestOutput.from_base(sim) for sim in similarities]

    def _late_interaction_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Score pairs with late interaction (ColBERT MaxSim).

        Queries and documents are encoded into per-token embeddings with one
        `encode` call (task `"token_embed"`). Each pair is then scored with
        MaxSim: the sum over query tokens of the maximum similarity to any
        document token. A single query is broadcast against every document.
        """
        from vllm.outputs import PoolingOutput

        tokenizer = self.get_tokenizer()

        # ScoreData -> prompts (text or multimodal), tagged by role.
        model_config = self.model_config
        query_prompts = score_data_to_prompts(data_1, "query", model_config)
        doc_prompts = score_data_to_prompts(data_2, "document", model_config)

        token_embeds: list[PoolingRequestOutput] = self.encode(
            query_prompts + doc_prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        num_queries = len(query_prompts)
        query_outs: list[PoolingRequestOutput] = token_embeds[:num_queries]
        doc_outs: list[PoolingRequestOutput] = token_embeds[num_queries:]

        if len(query_outs) == 1:
            query_outs = query_outs * len(doc_outs)

        # The pad token (when the tokenizer has one) separates the query and
        # document token ids in the reported prompt_token_ids.
        pad_token_id = tokenizer.pad_token_id
        separator: list[int] = [] if pad_token_id is None else [pad_token_id]

        pair_outputs: list[PoolingRequestOutput] = []
        for q_out, d_out in zip(query_outs, doc_outs):
            # q_out.outputs.data: [query_len, dim]
            # d_out.outputs.data: [doc_len, dim]
            maxsim = compute_maxsim_score(q_out.outputs.data, d_out.outputs.data)
            joined_ids = q_out.prompt_token_ids + separator + d_out.prompt_token_ids
            cached = q_out.num_cached_tokens + d_out.num_cached_tokens
            pair_outputs.append(
                PoolingRequestOutput(
                    request_id=f"{q_out.request_id}_{d_out.request_id}",
                    outputs=PoolingOutput(data=maxsim),
                    prompt_token_ids=joined_ids,
                    num_cached_tokens=cached,
                    finished=True,
                )
            )

        return [ScoringRequestOutput.from_base(out) for out in pair_outputs]

    def _cross_encoding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
        score_template: str | None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by feeding each joint prompt through a cross-encoder.

        Each `(data_1[i], data_2[i])` pair is turned into one prompt via
        `get_score_prompt`. A single `data_1` entry is broadcast against every
        `data_2` entry.

        Args:
            data_1: Query-side score data.
            data_2: Document-side score data.
            use_tqdm: Progress-bar setting forwarded to `_run_completion`.
            pooling_params: Pooling parameters. When None, a fresh
                `PoolingParams(task="score")` is used. When its `task` is None,
                a copy with `task="score"` is used, and the caller's object is
                left untouched.
            lora_request: LoRA request(s) to apply, if any.
            tokenization_kwargs: Keyword arguments for tokenizer encoding.
            score_template: Optional template forwarded to `get_score_prompt`.

        Returns:
            One `ScoringRequestOutput` per pair, in input order.

        Raises:
            ValueError: If the tokenizer is a Mistral tokenizer.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        if is_mistral_tokenizer(tokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        elif pooling_params.task is None:
            # Set the task on a copy: mutating the caller's PoolingParams
            # would be an unexpected side effect of scoring.
            pooling_params = pooling_params.clone()
            pooling_params.task = "score"

        pooling_params_list: list[PoolingParams] = []
        prompts: list[PromptType] = []

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            # token_type_ids are removed from the prompt. When non-empty, they
            # are forwarded in compressed form on a per-request params copy.
            if token_type_ids := engine_prompt.pop("token_type_ids", None):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        outputs = self._run_completion(
            prompts=prompts,
            params=pooling_params_list,
            output_type=PoolingRequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        return [ScoringRequestOutput.from_base(item) for item in outputs]

    def score(
        self,
        data_1: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        data_2: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Compute similarity scores for `<data_1, data_2>` pairs.

        Pairs may be text/text or multi-modal/multi-modal. Three input shapes
        are accepted: `1 -> 1`, `1 -> N` and `N -> N`. In the `1 -> N` case the
        single `data_1` item is repeated `N` times so it pairs with every
        `data_2` item.

        The pairs are turned into prompts and batched automatically under the
        memory constraint. For the best performance, pass all of your inputs in
        a single call rather than calling this method in a loop.

        Multi-modal inputs (images, etc.) require an appropriate multi-modal
        model, and the prompt structure must match the model's expected input
        format.

        Args:
            data_1: A single prompt, a list of prompts, or
                `ScoreMultiModalParam`, holding either text or multi-modal
                data. When a list, it must have the same length as `data_2`.
            data_2: The data paired with the query to form the model input.
                Text or multi-modal; see [PromptType][vllm.inputs.PromptType]
                for the format of each prompt.
            use_tqdm: `True` shows a tqdm progress bar. A callable (e.g.
                `functools.partial(tqdm, leave=False)`) is used to build the
                bar instead. `False` disables the bar.
            pooling_params: Pooling parameters. If None, the default pooling
                parameters are used.
            lora_request: LoRA request to use, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: Chat template used for scoring; only allowed for
                cross-encoder models. If None, the model's default chat
                template is used.

        Returns:
            One `ScoringRequestOutput` per pair, in the same order as the
            inputs.

        Raises:
            ValueError: If the model is not a pooling model, supports none of
                the scoring paths, is a cross-encoder whose `num_labels` is not
                1, or if `chat_template` is given for a non-cross-encoder.
        """
        model_config = self.model_config

        if model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks

        # Late interaction models (e.g., ColBERT) score via token_embed, so
        # they do not need "embed" or "classify" support.
        is_late_interaction = model_config.is_late_interaction
        if not is_late_interaction and not any(
            task in supported_tasks for task in ("embed", "classify")
        ):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        cross_encoder = model_config.is_cross_encoder
        num_labels = getattr(model_config.hf_config, "num_labels", 0)
        if cross_encoder and num_labels != 1:
            raise ValueError("Score API is only enabled for num_labels == 1.")

        if chat_template is not None and not cross_encoder:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        score_data_1, score_data_2 = validate_score_input(
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            is_multimodal_model=model_config.is_multimodal_model,
            architecture=model_config.architecture,
        )

        # Merge user overrides into the renderer's default tokenization
        # params, then extract the kwargs actually passed to encoding.
        merged_tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )
        encode_kwargs = merged_tok_params.get_encode_kwargs()

        if cross_encoder:
            return self._cross_encoding_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
                score_template=chat_template,
            )

        if is_late_interaction:
            return self._late_interaction_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
            )

        return self._embedding_score(
            score_data_1,
            score_data_2,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
        )

    def start_profile(self, profile_prefix: str | None = None) -> None:
        """Start profiling with optional custom trace prefix.

        Args:
            profile_prefix: Optional prefix for the trace file names. If provided,
                           trace files will be named as "<prefix>_dp<X>_pp<Y>_tp<Z>".
                           If not provided, default naming will be used.
        """
        self.llm_engine.start_profile(profile_prefix)
1664
1665
1666
1667

    def stop_profile(self) -> None:
        """Stop profiling on the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache and return the engine's result.

        Both flags are forwarded positionally to
        `llm_engine.reset_prefix_cache`.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(reset_running_requests, reset_connector)

    def sleep(self, level: int = 1, mode: PauseMode = "abort"):
        """
        Put the engine to sleep; it should not process any requests.

        The caller must make sure no requests are being processed while the
        engine sleeps, i.e. until `wake_up` is called.

        Args:
            level: The sleep level.
                - Level 0: Pause scheduling but continue accepting requests.
                           Requests are queued but not processed.
                - Level 1: Offload model weights to CPU and discard the KV
                           cache (its contents are forgotten). Suited to
                           sleeping and waking the engine to run the same
                           model again. Make sure there is enough CPU memory
                           to hold the model weights.
                - Level 2: Discard all GPU memory (weights + KV cache). Suited
                           to sleeping and waking the engine to run a
                           different or updated model, where the previous
                           weights are not needed. Reduces CPU memory
                           pressure.
            mode: How to handle existing requests: "abort", "wait", or "keep".
        """
        engine = self.llm_engine
        engine.sleep(level=level, mode=mode)

    def wake_up(self, tags: list[str] | None = None):
1701
        """
1702
1703
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1704

1705
        Args:
1706
1707
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1708
1709
1710
1711
                `("weights", "kv_cache", "scheduling")`. If None, all memory
                is reallocated. wake_up should be called with all tags
                (or None) before the engine is used again.
                Use tags=["scheduling"] to resume from level 0 sleep.
1712
1713
        """
        self.llm_engine.wake_up(tags)
1714

1715
1716
1717
1718
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of the aggregated Prometheus metrics.

        Returns:
            A list of `Metric` objects describing the current state of all
            aggregated metrics, as reported by `llm_engine.get_metrics()`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        engine = self.llm_engine
        return engine.get_metrics()

    def _params_to_seq(
        self,
        params: _P | Sequence[_P],
        num_requests: int,
    ) -> Sequence[_P]:
        """Normalize `params` to one entry per request.

        Args:
            params: A single params object shared by all requests, or a
                sequence with exactly one entry per request.
            num_requests: The number of requests (prompts).

        Returns:
            `params` unchanged if it is a sequence; otherwise a list that
            repeats the single object `num_requests` times.

        Raises:
            ValueError: If `params` is a sequence whose length differs from
                `num_requests`.
        """
        if isinstance(params, Sequence):
            if len(params) != num_requests:
                # Report the prompt count, not the params object itself, to
                # match the sibling *_to_seq helpers.
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and params ({len(params)}) must be the same."
                )

            return params

        return [params] * num_requests

    def _lora_request_to_seq(
        self,
        lora_request: LoRARequest | None | Sequence[LoRARequest | None],
        num_requests: int,
    ) -> Sequence[LoRARequest | None]:
        """Normalize `lora_request` to one entry per request.

        A single request (or None) is repeated `num_requests` times. A
        sequence is returned as-is after checking its length.

        Raises:
            ValueError: If a sequence is given whose length differs from
                `num_requests`.
        """
        if not isinstance(lora_request, Sequence):
            return [lora_request] * num_requests

        if len(lora_request) != num_requests:
            raise ValueError(
                f"The lengths of prompts ({num_requests}) "
                f"and lora_request ({len(lora_request)}) must be the same."
            )
        return lora_request

    def _priority_to_seq(
        self,
        priority: list[int] | None,
        num_requests: int,
    ) -> Sequence[int]:
1764
1765
1766
1767
1768
1769
1770
        if priority is not None:
            if len(priority) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and priority ({len(priority)}) must be the same."
                )

1771
1772
1773
1774
            return priority

        return [0] * num_requests

1775
    def _add_completion_requests(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Render completion prompts and add them to the engine as requests.

        Args:
            prompts: A single prompt or a sequence of prompts.
            params: One params object for all prompts, or one per prompt.
            use_tqdm: Progress-bar setting for the rendering loop.
            lora_request: One LoRA request for all prompts, or one per prompt.
            priority: Optional per-prompt priorities; defaults to 0 for all.
            tokenization_kwargs: Forwarded to `_preprocess_cmpl_one`.

        Returns:
            The IDs of the requests that were added, in prompt order.

        Raises:
            ValueError: If `params`, `lora_request` or `priority` is a
                per-request sequence whose length differs from the number of
                prompts.
        """
        seq_prompts = prompt_to_seq(prompts)
        num_requests = len(seq_prompts)

        seq_params = self._params_to_seq(params, num_requests)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
        # Use the normalized count, not len(prompts). For a single prompt such
        # as a plain string, len(prompts) would count characters rather than
        # requests.
        seq_priority = self._priority_to_seq(priority, num_requests)

        return self._render_and_add_requests(
            prompts=(
                self._preprocess_cmpl_one(prompt, tokenization_kwargs)
                for prompt in maybe_tqdm(
                    seq_prompts,
                    use_tqdm=use_tqdm,
                    desc="Rendering prompts",
                )
            ),
            params=seq_params,
            lora_requests=seq_lora_requests,
            priorities=seq_priority,
        )

    def _run_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ):
        """Add completion requests for `prompts`, then run them to completion.

        Returns the finished outputs of type `output_type` as produced by
        `_run_engine`.
        """
        self._add_completion_requests(
            prompts=prompts,
            params=params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
        )
        finished = self._run_engine(output_type, use_tqdm=use_tqdm)
        return finished

    def _run_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ):
        """Render one or more conversations and run them to completion.

        Each conversation is rendered with `_preprocess_chat_one` using the
        given chat-template options. The rendered prompts are then added to
        and executed by the engine.

        Returns:
            The finished outputs of type `output_type`, as produced by
            `_render_and_run_requests`.
        """
        conversations = conversation_to_seq(messages)
        num_convs = len(conversations)
        per_conv_params = self._params_to_seq(params, num_convs)
        per_conv_loras = self._lora_request_to_seq(lora_request, num_convs)

        # A generator, so each conversation is rendered only when the next
        # request is about to be added.
        rendered_prompts = (
            self._preprocess_chat_one(
                conv,
                chat_template=chat_template,
                chat_template_content_format=chat_template_content_format,
                chat_template_kwargs=chat_template_kwargs,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                tokenization_kwargs=tokenization_kwargs,
                mm_processor_kwargs=mm_processor_kwargs,
            )
            for conv in maybe_tqdm(
                conversations,
                use_tqdm=use_tqdm,
                desc="Rendering conversations",
            )
        )

        return self._render_and_run_requests(
            prompts=rendered_prompts,
            params=per_conv_params,
            output_type=output_type,
            lora_requests=per_conv_loras,
            use_tqdm=use_tqdm,
        )

    def _render_and_run_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ):
        """Add the rendered `prompts` as requests and run them to completion.

        A one-time warning is logged when `prompts` is an already materialized
        list or tuple, because a lazy generator lets engine execution overlap
        with rendering.

        Returns:
            The finished outputs of type `output_type` from `_run_engine`.
        """
        already_materialized = isinstance(prompts, (list, tuple))
        if already_materialized:
            logger.warning_once(
                "Rendering all prompts before adding them to the engine "
                "is less efficient than performing both on the same prompt "
                "before processing the next prompt. You should instead pass "
                "a generator that renders one prompt per iteration, as that allows "
                "engine execution to begin for the first prompt while processing "
                "the next prompt."
            )

        self._render_and_add_requests(
            prompts=prompts,
            params=params,
            lora_requests=lora_requests,
            priorities=priorities,
        )

        finished = self._run_engine(output_type, use_tqdm=use_tqdm)
        return finished

    def _render_and_add_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
    ) -> list[str]:
        """Add each prompt to the engine as a request.

        The i-th prompt is paired with `params[i]`, with `lora_requests[i]`
        (resolved through `_resolve_mm_lora`) and with `priorities[i]`. When
        `lora_requests` is None, no LoRA request is used; when `priorities` is
        None, priority 0 is used.

        If anything raises part-way through, the requests added so far are
        aborted before the exception is re-raised.

        Returns:
            The IDs of the added requests, in prompt order.
        """
        added_ids: list[str] = []

        try:
            for idx, prompt in enumerate(prompts):
                request_params = params[idx]
                resolved_lora = self._resolve_mm_lora(
                    prompt,
                    None if lora_requests is None else lora_requests[idx],
                )
                request_priority = 0 if priorities is None else priorities[idx]

                new_id = self._add_request(
                    prompt,
                    request_params,
                    lora_request=resolved_lora,
                    priority=request_priority,
                )
                added_ids.append(new_id)
        except Exception as err:
            # Roll back any partial batch so no orphaned requests remain.
            if added_ids:
                self.llm_engine.abort_request(added_ids, internal=True)
            raise err

        return added_ids

    def _add_request(
        self,
        prompt: ProcessorInputs,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
    ) -> str:
        """Submit one request to the engine and return the engine's result.

        The request ID is the next value of `self.request_counter`, converted
        to a string.

        Note: `SamplingParams` are modified in place so that only the final
        output is produced.
        """
        if isinstance(params, SamplingParams):
            # Callers only collect finished outputs, so intermediate
            # streaming deltas are not needed.
            params.output_kind = RequestOutputKind.FINAL_ONLY

        new_request_id = str(next(self.request_counter))
        engine = self.llm_engine
        return engine.add_request(
            new_request_id,
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
        )

    def _run_engine(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]:
        """Step the engine until it has no unfinished requests.

        Args:
            output_type: Expected type(s) of every output produced by
                `llm_engine.step()`. This is checked with an assertion.
            use_tqdm: `True` shows a tqdm progress bar. A callable is used to
                construct the bar instead. `False` disables it.

        Returns:
            All finished outputs, sorted by integer request ID (i.e. in
            submission order).
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[_O] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                assert isinstance(output, output_type)
                if not output.finished:
                    continue

                outputs.append(output)  # type: ignore[arg-type]
                if not use_tqdm:
                    continue

                if isinstance(output, RequestOutput):
                    # Calculate tokens only for RequestOutput
                    n = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids) * n
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs
                    )
                    # tqdm can report elapsed == 0, e.g. for a disabled bar
                    # (TQDM_DISABLE=1 or disable=True) or right after the bar
                    # is created. Skip the speed estimate then instead of
                    # raising ZeroDivisionError.
                    elapsed = pbar.format_dict.get("elapsed") or 0
                    if elapsed > 0:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s"
                        )
                    pbar.update(n)
                else:
                    pbar.update(1)

                if pbar.n == num_requests:
                    pbar.refresh()

        if use_tqdm:
            pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))

    def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest | dict
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Args:
            request: Weight transfer initialization request, or an equivalent
                dict. Its `init_info` (backend-specific) is forwarded to the
                workers via `collective_rpc("init_weight_transfer_engine")`.
        """
        if isinstance(request, dict):
            init_info = request["init_info"]
        else:
            init_info = request.init_info

        self.llm_engine.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": init_info}
        )

    def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
        """
        Update the model weights.

        Args:
            request: Weight update request, or an equivalent dict. Its
                `update_info` (backend-specific) is forwarded to the workers
                via `collective_rpc("update_weights")`.
        """
        if isinstance(request, dict):
            update_info = request["update_info"]
        else:
            update_info = request.update_info

        self.llm_engine.collective_rpc(
            "update_weights", kwargs={"update_info": update_info}
        )

    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model."""
        # Cache the result to avoid repeated collective_rpc calls
        if self._cached_repr is None:
            results = self.llm_engine.collective_rpc("get_model_inspection")
            # In distributed settings, we get results from all workers
            # Just return the first one (they should all be the same)
            if results:
                self._cached_repr = results[0]
            else:
                self._cached_repr = f"LLM(model={self.model_config.model!r})"
        return self._cached_repr