llm.py 83.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
7
from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Any
8

9
import cloudpickle
10
import torch.nn as nn
11
from pydantic import ValidationError
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar
14

15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
22
    AttentionConfig,
23
    CompilationConfig,
24
    PoolerConfig,
25
    ProfilerConfig,
26
27
28
    StructuredOutputsConfig,
    is_init_field,
)
29
from vllm.config.compilation import CompilationMode
30
from vllm.config.model import (
31
32
    ConvertOption,
    HfOverrides,
33
    ModelDType,
34
    RunnerOption,
35
    TokenizerMode,
36
)
37
38
39
40
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
41
from vllm.engine.arg_utils import EngineArgs
42
43
44
45
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
46
from vllm.entrypoints.pooling.score.utils import (
47
    ScoreData,
48
49
50
    ScoreMultiModalParam,
    _cosine_similarity,
    compress_token_type_ids,
51
    compute_maxsim_score,
52
    get_score_prompt,
53
    validate_score_input,
54
)
55
from vllm.entrypoints.utils import log_non_default_args
56
from vllm.inputs.data import (
57
    DataPrompt,
58
    ProcessorInputs,
59
60
61
62
63
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
64
from vllm.logger import init_logger
65
from vllm.lora.request import LoRARequest
66
from vllm.model_executor.layers.quantization import QuantizationMethods
67
68
69
70
71
72
73
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
74
from vllm.platforms import current_platform
75
from vllm.pooling_params import PoolingParams
76
from vllm.renderers import ChatParams, merge_kwargs
77
78
79
80
81
from vllm.renderers.inputs.preprocess import (
    conversation_to_seq,
    parse_model_prompt,
    prompt_to_seq,
)
82
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
83
from vllm.tasks import PoolingTask
84
85
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
yhu422's avatar
yhu422 committed
86
from vllm.usage.usage_lib import UsageContext
87
from vllm.utils.counter import Counter
88
from vllm.utils.tqdm_utils import maybe_tqdm
89
from vllm.v1.engine.llm_engine import LLMEngine
90
from vllm.v1.sample.logits_processor import LogitsProcessor
91

92
93
94
if TYPE_CHECKING:
    # Only needed for type annotations; skipped at runtime.
    from vllm.v1.metrics.reader import Metric

logger = init_logger(__name__)

# Type variable bounded to the parameter types this module works with:
# SamplingParams (generation), PoolingParams (pooling), or None.
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
# Generic result type (e.g. for collective_rpc / apply_model). The
# typing_extensions `default=Any` makes an unsolved _R behave as Any.
_R = TypeVar("_R", default=Any)

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If True, skip initialization of the tokenizer and
            detokenizer. Inputs are then expected to provide valid
            `prompt_token_ids` and `None` for `prompt`.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
        allowed_media_domains: If set, only media URLs that belong to these
            domains can be used for multi-modal inputs.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause
            out-of-memory (OOM) errors.
        kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. By
            default, this is None and vLLM infers the KV cache size from
            `gpu_memory_utilization`. Setting it explicitly gives finer-grained
            control over how much memory is used than
            `gpu_memory_utilization` does. Note that when
            `kv_cache_memory_bytes` is set, `gpu_memory_utilization` is
            ignored.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        enable_return_routed_experts: Whether to return routed experts.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `hf auth login` (stored in `~/.cache/huggingface/token`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        pooler_config: Initialize non-default pooling config for the pooling model,
            e.g., `PoolerConfig(seq_pooling_type="MEAN", use_activation=False)`.
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the mode of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
            A `CompilationConfig` instance is also accepted.
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """
198
199
200
201

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Builds `EngineArgs` from the arguments (documented on the class),
        creates the engine, and caches handles to its components. Any extra
        keyword arguments are forwarded to `EngineArgs`.

        Raises:
            ValueError: If `kv_transfer_config` is given as a dict that fails
                validation, or if `data_parallel_size > 1` is requested
                without the external launcher backend on a non-TPU platform.
        """
        # Stats logging is off unless the caller explicitly opts in.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # If worker_cls is a class rather than a qualified-name string,
            # serialize it with cloudpickle to avoid pickling issues.
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if isinstance(kwargs.get("kv_transfer_config"), dict):
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                # Surface validation failures as a ValueError with context.
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance.

            None yields `cls()`; a dict is passed as keyword arguments after
            dropping keys that are not init fields of `cls`; anything else is
            returned unchanged.
            """
            if value is None:
                return cls()
            if isinstance(value, dict):
                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
            return value

        # An int compilation_config is shorthand for the compilation mode.
        if isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data parallelism: data_parallel_size > 1 is
        # only accepted with the external_launcher backend or on TPU.
        dp_size = int(kwargs.get("data_parallel_size", 1))
        executor_backend = kwargs.get("distributed_executor_backend")
        if (
            dp_size > 1
            and executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Populated lazily by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        # Shortcuts to engine components used by the methods of this class.
        self.model_config = self.llm_engine.model_config
        self.renderer = self.llm_engine.renderer
        self.io_processor = self.llm_engine.io_processor
        self.input_processor = self.llm_engine.input_processor

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

366
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer provided by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
368

369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
    def get_world_size(self, include_dp: bool = True) -> int:
        """Get the world size from the parallel config.

        Args:
            include_dp: If True (default), returns the world size including
                data parallelism (TP * PP * DP). If False, returns the world
                size without data parallelism (TP * PP).

        Returns:
            The world size (tensor_parallel_size * pipeline_parallel_size),
            optionally multiplied by data_parallel_size if include_dp is True.
        """
        parallel_config = self.llm_engine.vllm_config.parallel_config
        if include_dp:
            return parallel_config.world_size_across_dp
        return parallel_config.world_size

386
    def reset_mm_cache(self) -> None:
387
        self.renderer.clear_mm_cache()
388
389
        self.llm_engine.reset_mm_cache()

390
    def get_default_sampling_params(self) -> SamplingParams:
        """Build a new `SamplingParams` from the model's default overrides.

        The overrides come from `model_config.get_diff_sampling_param()` and
        are stored on the instance after the first call. If they are empty,
        a default-constructed `SamplingParams` is returned instead.
        """
        cached = self.default_sampling_params
        if cached is None:
            cached = self.model_config.get_diff_sampling_param()
            self.default_sampling_params = cached
        if not cached:
            return SamplingParams()
        return SamplingParams.from_optional(**cached)

397
398
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """Generate completions for the given prompts.

        This method batches the prompts automatically, taking the memory
        constraint into account. For the best performance, pass all of your
        prompts to a single call.

        Args:
            prompts: One prompt, or a sequence of prompts for batch inference.
                See [PromptType][vllm.inputs.PromptType] for the accepted
                formats.
            sampling_params: Sampling parameters for text generation. When
                None, the model's default sampling parameters are used. A
                single value applies to every prompt; a sequence must have the
                same length as `prompts` and is paired with it element-wise.
            use_tqdm: `True` shows a tqdm progress bar and `False` shows none.
                A callable (e.g., `functools.partial(tqdm, leave=False)`) is
                used to create the progress bar.
            lora_request: LoRA request to use for generation, if any.
            priority: Request priorities, if any. Only applicable when the
                priority scheduling policy is enabled. If provided, it must be
                a list of integers of the same length as `prompts`, where each
                value applies to the prompt at the same index.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            `RequestOutput` objects with the generated completions, in the
            same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a generative model.
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        params = (
            self.get_default_sampling_params()
            if sampling_params is None
            else sampling_params
        )

        completions = self._run_completion(
            prompts=prompts,
            params=params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
        return self.engine_class.validate_outputs(completions, RequestOutput)
460

461
462
463
464
    def enqueue(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Enqueue prompts for generation without waiting for completion.

        This method adds requests to the engine queue but does not start
        processing them. Use wait_for_completion() to process the queued
        requests and get results.

        Args:
            prompts: The prompts to the LLM. See generate() for details.
            sampling_params: The sampling parameters for text generation. If
                None, the model's default sampling parameters are used.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any. Expected to have
                one entry per prompt (see generate()).
            use_tqdm: If True, shows a tqdm progress bar while rendering the
                prompts.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of request IDs for the enqueued requests.

        Raises:
            ValueError: If the model is not run as a generative model.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError("LLM.enqueue() is only supported for generative models.")

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        # Use the same preprocessing as _run_completion. Every per-request
        # sequence is sized by the normalized prompt list: `prompts` may be a
        # single prompt (e.g. a str or a prompt dict), whose len() is the
        # character/key count rather than the number of requests.
        seq_prompts = prompt_to_seq(prompts)
        num_requests = len(seq_prompts)
        seq_params = self._params_to_seq(sampling_params, num_requests)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
        # Per-request tokenization kwargs: the caller's overrides combined
        # with each request's own truncate_prompt_tokens setting.
        seq_tok_kwargs = [
            merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
            )
            for param in seq_params
        ]
        seq_priority = self._priority_to_seq(priority, num_requests)

        request_ids = self._render_and_add_requests(
            prompts=(
                self._preprocess_cmpl_one(prompt, tok_kwargs)
                for prompt, tok_kwargs in zip(
                    maybe_tqdm(
                        seq_prompts,
                        use_tqdm=use_tqdm,
                        desc="Rendering prompts",
                    ),
                    seq_tok_kwargs,
                )
            ),
            params=seq_params,
            lora_requests=seq_lora_requests,
            tokenization_kwargs=tokenization_kwargs,
            priorities=seq_priority,
        )

        return request_ids

    def wait_for_completion(
        self,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[RequestOutput]:
        """Run the engine until the queued requests finish, then return them.

        Meant to follow one or more enqueue() calls: every request currently
        in the engine queue is processed and its output collected.

        Args:
            use_tqdm: If True, display a tqdm progress bar while running.

        Returns:
            The `RequestOutput` objects of all completed requests.
        """
        finished = self._run_engine(use_tqdm=use_tqdm)
        validated = self.engine_class.validate_outputs(finished, RequestOutput)
        return validated

546
    def _resolve_lora_reqs(
        self,
        prompts: Sequence[ProcessorInputs],
        lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
    ):
        """Produce one LoRA request (or None) per prompt.

        `lora_request` is first normalized to one entry per prompt with
        `_lora_request_to_seq`. Default modality-specific LoRAs are then
        considered only when LoRA is configured, the model is multimodal,
        and `default_mm_loras` is non-empty.

        Args:
            prompts: Processed prompts, one per request.
            lora_request: A single LoRA request, a per-prompt sequence, or
                None.

        Returns:
            The per-prompt LoRA requests, possibly replaced by a default
            multimodal LoRA (see `_resolve_single_prompt_mm_lora`).
        """
        lora_config = self.llm_engine.vllm_config.lora_config
        seq_lora_requests = self._lora_request_to_seq(lora_request, len(prompts))

        # lora_config is known to be non-None by the third test, so it needs
        # no extra guard. An empty default_mm_loras has nothing to apply, so
        # it short-circuits as well (the per-prompt path would return every
        # request unchanged anyway).
        if (
            lora_config is None
            or not self.model_config.is_multimodal_model
            or not lora_config.default_mm_loras
        ):
            return seq_lora_requests

        default_mm_loras = lora_config.default_mm_loras
        return [
            self._resolve_single_prompt_mm_lora(prompt, lora_req, default_mm_loras)
            for prompt, lora_req in zip(prompts, seq_lora_requests)
        ]

570
571
    def _resolve_single_prompt_mm_lora(
        self,
        prompt: ProcessorInputs,
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
    ) -> LoRARequest | None:
        """Choose the LoRA request for a single prompt.

        Args:
            prompt: The processed prompt. Only prompts whose `"type"` is
                `"multimodal"` can pick up a default LoRA.
            lora_request: The LoRA request explicitly given for this prompt.
            default_mm_loras: Mapping from modality name to LoRA path.

        Returns:
            `lora_request` unchanged when there are no defaults, the prompt is
            not multimodal, none of its modalities has a default LoRA, more
            than one of them does, or an explicit request was given.
            Otherwise, a new `LoRARequest` for the single matching modality.
        """
        if not default_mm_loras or prompt["type"] != "multimodal":
            return lora_request

        prompt_modalities = prompt["mm_placeholders"].keys()
        intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped",
                intersection,
            )
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1, so IDs
        # start at 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; warn only when its ID
        # differs from the default LoRA ID for this modality.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt"
                )
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

618
619
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Invoke a method on every worker and gather what each one returns.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and sent to each worker for execution. A callable
                must accept an additional `self` argument (the worker object)
                on top of whatever is passed via `args` and `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait. On expiry a
                [`TimeoutError`][] is raised; `None` waits indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker.

        Note:
            Prefer using this API for control messages only, and set up a
            separate data-plane communication path for moving data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
649
650

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model inside every worker and collect the
        per-worker return values.

        !!! warning
            Returning large arrays or tensors from `func` adds data-transfer
            overhead, so avoid it. If you must return them, move them to CPU
            first so they do not occupy additional VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
662

663
664
    def beam_search(
        self,
        prompts: list[TokensPrompt | TextPrompt],
        params: BeamSearchParams,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        use_tqdm: bool = False,
        concurrency_limit: int | None = None,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a text prompt
                (`TextPrompt`) or a token-ID prompt (`TokensPrompt`).
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            use_tqdm: Whether to use tqdm to display the progress bar.
                Disabled (with a warning) when `concurrency_limit` is set.
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.

        Returns:
            One `BeamSearchOutput` per prompt, in input order, each holding
            at most `params.beam_width` sequences sorted best-first.

        Raises:
            ValueError: If `concurrency_limit` is less than 1.
            NotImplementedError: If a prompt is an embedding prompt or an
                encoder-decoder prompt.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        engine_prompts = self._preprocess_cmpl(prompts)
        lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts))

        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar."
            )
            use_tqdm = False

        if concurrency_limit is None:
            # `range()` below rejects a step of 0, so an empty prompt list
            # must still produce a positive batch size.
            concurrency_limit = max(len(engine_prompts), 1)
        elif concurrency_limit < 1:
            raise ValueError(
                f"concurrency_limit must be at least 1, got {concurrency_limit}."
            )

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        sampling_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
        )
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, engine_prompts):
            if prompt["type"] == "embeds":
                raise NotImplementedError(
                    "Embedding prompt not supported for beam search"
                )
            if prompt["type"] == "enc_dec":
                raise NotImplementedError(
                    "Encoder-decoder prompt not supported for beam search"
                )

            instances.append(
                BeamSearchInstance(
                    prompt,
                    lora_request=lora_req,
                    logprobs=None,
                ),
            )

        for prompt_start in range(0, len(instances), concurrency_limit):
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress."
                )
            for _ in token_iter:
                # Flatten the live beams of every instance in this batch.
                # `chain.from_iterable` avoids the quadratic cost of
                # `sum(list_of_lists, [])`.
                all_beams: list[BeamSearchSequence] = list(
                    itertools.chain.from_iterable(
                        instance.beams for instance in instances_batch
                    )
                )
                if len(all_beams) == 0:
                    break

                # `pos[k]` is the offset of instance k's first beam inside
                # `all_beams`, so `(pos[k], pos[k + 1])` delimits its beams.
                pos = list(
                    itertools.accumulate(
                        (len(instance.beams) for instance in instances_batch),
                        initial=0,
                    )
                )
                instance_start_and_end: list[tuple[int, int]] = list(
                    itertools.pairwise(pos)
                )

                # only runs for one step
                # we don't need to use tqdm here
                raw_output = self._render_and_run_requests(
                    prompts=(beam.get_prompt() for beam in all_beams),
                    params=self._params_to_seq(sampling_params, len(all_beams)),
                    lora_requests=[beam.lora_request for beam in all_beams],
                    use_tqdm=False,
                )
                output = self.engine_class.validate_outputs(raw_output, RequestOutput)

                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    current_beam.orig_prompt,
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs + [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                )

                                if token_id == eos_token_id and not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
                    instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after the last step compete with the ones
            # that already finished on EOS.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)

            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def _preprocess_cmpl(
        self,
        prompts: Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
        a format that can be passed to `_add_request`.

        Refer to [LLM.generate][] for a complete description of the arguments.

        Args:
            prompts: The prompts to convert. Each one is parsed with
                `parse_model_prompt` against the model config before rendering.
            tokenization_kwargs: Overrides applied on top of the renderer's
                default completion tokenization parameters.

        Returns:
            A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
        """
        renderer = self.renderer
        model_config = self.model_config

        parsed_prompts = [
            parse_model_prompt(model_config, prompt) for prompt in prompts
        ]
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

        return renderer.render_cmpl(parsed_prompts, tok_params)

    def _preprocess_cmpl_one(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """
        Single-prompt convenience wrapper around `_preprocess_cmpl`.

        Raises:
            ValueError: If rendering does not yield exactly one engine prompt
                (raised by the tuple unpacking).
        """
        (engine_prompt,) = self._preprocess_cmpl([prompt], tokenization_kwargs)
        return engine_prompt

    def _preprocess_chat(
        self,
        conversations: Sequence[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Convert a list of conversations into prompts so that they can then
        be used as input for other LLM APIs.

        Refer to [LLM.chat][] for a complete description of the arguments.

        `add_generation_prompt`, `continue_final_message` and `tools` are
        combined with `chat_template_kwargs` via `merge_kwargs`, together with
        a `tokenize` flag that is `True` only when the renderer's tokenizer is
        a `MistralTokenizer`.

        Returns:
            A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
        """
        renderer = self.renderer

        chat_params = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=merge_kwargs(
                chat_template_kwargs,
                dict(
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                    tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
                ),
            ),
        )
        tok_params = renderer.default_chat_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

        _, engine_prompts = renderer.render_chat(
            conversations,
            chat_params,
            tok_params,
            prompt_extras={"mm_processor_kwargs": mm_processor_kwargs},
        )

        return engine_prompts

    def _preprocess_chat_one(
        self,
        conversation: list[ChatCompletionMessageParam],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """
        Single-conversation convenience wrapper around `_preprocess_chat`.

        All keyword arguments are forwarded unchanged.

        Raises:
            ValueError: If rendering does not yield exactly one engine prompt
                (raised by the tuple unpacking).
        """
        (engine_prompt,) = self._preprocess_chat(
            [conversation],
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        return engine_prompt

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the [generate][vllm.LLM.generate] method to generate
        the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A sequence of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Tool definitions passed through to chat template rendering.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            mm_processor_kwargs: Overrides for `processor.__call__`.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.

        Raises:
            ValueError: If the model's runner type is not `"generate"`.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.chat() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        outputs = self._run_chat(
            messages=messages,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt. A dict with
                a `"data"` key is routed through the installed IOProcessor
                plugin instead.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Deprecated. Pass it via
                `tokenization_kwargs` instead.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use. Must be set.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            When the IOProcessor plugin is used, a single-element list that
            wraps the plugin's post-processed output is returned instead.

        Raises:
            ValueError: If `pooling_task` is None, the model is not a pooling
                model, an IOProcessor prompt is given without a plugin
                installed or with `data=None`, or a `PoolingParams.task`
                conflicts with `pooling_task`.
        """
        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if truncate_prompt_tokens is not None:
            warnings.warn(
                "The `truncate_prompt_tokens` parameter in `LLM.encode()` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            prompt_data = prompts.get("data")
            if prompt_data is None:
                raise ValueError(
                    "The 'data' field of the prompt is expected to contain "
                    "the prompt data and it cannot be None. "
                    "Refer to the documentation of the IOProcessor "
                    "in use for more details."
                )
            validated_prompt = self.io_processor.parse_data(prompt_data)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)
            prompts_seq = prompt_to_seq(prompts)

            params_seq: Sequence[PoolingParams] = [
                self.io_processor.merge_pooling_params(param)
                for param in self._params_to_seq(
                    pooling_params,
                    len(prompts_seq),
                )
            ]
            for p in params_seq:
                if p.task is None:
                    p.task = "plugin"
        else:
            if pooling_params is None:
                # Use default pooling params.
                pooling_params = PoolingParams()

            prompts_seq = prompt_to_seq(prompts)
            params_seq = self._params_to_seq(pooling_params, len(prompts_seq))

            for param in params_seq:
                if param.task is None:
                    param.task = pooling_task
                elif param.task != pooling_task:
                    msg = (
                        f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                    )
                    raise ValueError(msg)

        outputs = self._run_completion(
            prompts=prompts_seq,
            params=params_seq,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if use_io_processor:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(model_outputs)

            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
                    prompt_token_ids=[],
                    finished=True,
                )
            ]
        else:
            return model_outputs

    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If set, merged into `tokenization_kwargs`
                as `truncate_prompt_tokens`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"embed"` task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        # Folded in here (rather than forwarded) so `encode` does not emit
        # its `LLM.encode()`-specific deprecation warning.
        if truncate_prompt_tokens is not None:
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"classify"` task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: If set, merged into `tokenization_kwargs`
                as `truncate_prompt_tokens` (same handling as [LLM.embed][]).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        # Fold the legacy argument into `tokenization_kwargs` here, as `embed`
        # does, instead of forwarding it to `encode`. `encode` would emit a
        # deprecation warning naming `LLM.encode()`, and with `stacklevel=2`
        # it would be attributed to this method rather than the user's call.
        if truncate_prompt_tokens is not None:
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_classify",
            tokenization_kwargs=tokenization_kwargs,
        )

    def _embedding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Score text pairs by the cosine similarity of their embeddings.

        All items of `data_1` and `data_2` are embedded in a single `encode`
        call with `pooling_task="embed"`. If `data_1` holds exactly one item,
        its embedding is repeated to match the length of `data_2` (1:N
        scoring). Otherwise the two embedding lists are handed to
        `_cosine_similarity` as-is; presumably that helper pairs them
        element-wise, so confirm against it.

        Raises:
            NotImplementedError: If any item is not a plain string
                (multimodal inputs are not supported here).
        """
        tokenizer = self.get_tokenizer()

        input_texts: list[str] = []
        for text in data_1 + data_2:
            if not isinstance(text, str):
                raise NotImplementedError(
                    "Embedding scores currently do not support multimodal input."
                )
            input_texts.append(text)

        encoded_output = self.encode(
            input_texts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        # Split the joint batch back into the two sides.
        encoded_output_1 = encoded_output[0 : len(data_1)]
        encoded_output_2 = encoded_output[len(data_1) :]

        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(
            tokenizer=tokenizer,
            embed_1=encoded_output_1,
            embed_2=encoded_output_2,
        )

        items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _late_interaction_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Late interaction scoring (ColBERT MaxSim).

        Queries and documents are encoded in one `encode` call using the
        `"token_embed"` pooling task. Each pair is then scored with
        `compute_maxsim_score`: the sum over query tokens of the maximum
        similarity to any document token. When exactly one query is given,
        it is paired with every document.

        Raises:
            NotImplementedError: If any entry of `data_1` or `data_2` is not
                a plain string (multimodal input is not supported here).
        """
        from vllm.outputs import PoolingOutput

        tokenizer = self.get_tokenizer()

        texts: list[str] = []
        for item in data_1 + data_2:
            if not isinstance(item, str):
                raise NotImplementedError(
                    "Late interaction scores currently do not support multimodal input."
                )
            texts.append(item)

        per_token: list[PoolingRequestOutput] = self.encode(
            texts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        n_queries = len(data_1)
        queries: list[PoolingRequestOutput] = per_token[:n_queries]
        documents: list[PoolingRequestOutput] = per_token[n_queries:]
        if len(queries) == 1:
            queries = queries * len(documents)

        # The combined prompt token ids put one pad token between query and
        # document when the tokenizer defines a pad token.
        pad_ids: list[int] = []
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is not None:
            pad_ids = [pad_token_id]

        scores: list[PoolingRequestOutput] = []
        for query, doc in zip(queries, documents):
            # query.outputs.data: [query_len, dim]; doc.outputs.data: [doc_len, dim]
            score = compute_maxsim_score(query.outputs.data, doc.outputs.data)
            token_ids = query.prompt_token_ids + pad_ids + doc.prompt_token_ids
            scores.append(
                PoolingRequestOutput(
                    request_id=f"{query.request_id}_{doc.request_id}",
                    outputs=PoolingOutput(data=score),
                    prompt_token_ids=token_ids,
                    num_cached_tokens=query.num_cached_tokens + doc.num_cached_tokens,
                    finished=True,
                )
            )

        validated = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

1461
1462
    def _cross_encoding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
        score_template: str | None,
    ) -> list[ScoringRequestOutput]:
        """Score each `(data_1[i], data_2[i])` pair jointly with a cross-encoder.

        `get_score_prompt` renders each pair into one prompt, optionally using
        `score_template`. The prompts then run through the engine with
        `"score"` pooling parameters. A single `data_1` entry is paired with
        every entry of `data_2`.

        The caller's `pooling_params` object is never modified. When its
        `task` is unset, a clone with `task="score"` is used instead.

        Raises:
            ValueError: If the tokenizer is a `MistralTokenizer`.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        elif pooling_params.task is None:
            # Clone instead of mutating: the object belongs to the caller and
            # may be reused for other requests afterwards.
            pooling_params = pooling_params.clone()
            pooling_params.task = "score"

        pooling_params_list: list[PoolingParams] = []
        prompts: list[PromptType] = []

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            # token_type_ids are removed from the prompt. When present and
            # non-empty, they are forwarded in compressed form through a
            # per-request clone of the pooling params.
            if token_type_ids := engine_prompt.pop("token_type_ids", None):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        outputs = self._run_completion(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        items = self.engine_class.validate_outputs(outputs, PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

1523
1524
    def score(
        self,
        data_1: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        data_2: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Compute a similarity score for every `(data_1, data_2)` pair.

        Inputs may be paired `1 -> 1`, `1 -> N` or `N -> N`. In the `1 -> N`
        case, the single `data_1` entry is replicated `N` times so it pairs
        with each `data_2` entry. In the `N -> N` case, entries are paired by
        position.

        The scoring strategy depends on the model:

        - cross-encoder models score each pair jointly as one prompt;
        - late-interaction models (e.g. ColBERT) use MaxSim over per-token
          embeddings;
        - all other pooling models use the cosine similarity of the two
          embeddings.

        Prompts are batched automatically within memory constraints. For the
        best performance, pass all inputs in a single call.

        Text and multi-modal data (images, etc.) are accepted when the model
        supports them. Multi-modal prompts must follow the model's expected
        input structure.

        Args:
            data_1: A single prompt, a list of prompts, or
                `ScoreMultiModalParam`, containing text or multi-modal data.
                When a list, it must have the same length as `data_2`.
            data_2: The data paired with `data_1` to form the model input.
                Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for details on the prompt format.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: The chat template used to build each scoring
                prompt. Only accepted for cross-encoder models. If None, the
                model's default chat template is used.

        Returns:
            A list of `ScoringRequestOutput` objects containing the scores,
            in the same order as the input pairs.

        Raises:
            ValueError: If the model is not a pooling model; if it supports
                neither `embed` nor `classify` and is not late-interaction;
                if it is a cross-encoder with `num_labels != 1`; or if
                `chat_template` is given for a non-cross-encoder model.
            NotImplementedError: If non-text inputs reach the embedding or
                late-interaction path.
        """
        config = self.model_config

        if config.runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        tasks = self.supported_tasks
        # Late interaction models (e.g., ColBERT) score through token_embed,
        # so they need neither the "embed" nor the "classify" task.
        late_interaction = config.is_late_interaction
        if not late_interaction and "embed" not in tasks and "classify" not in tasks:
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        cross_encoder = config.is_cross_encoder
        if cross_encoder and getattr(config.hf_config, "num_labels", 0) != 1:
            raise ValueError("Score API is only enabled for num_labels == 1.")

        if chat_template is not None and not cross_encoder:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        score_data_1, score_data_2 = validate_score_input(
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            is_multimodal_model=config.is_multimodal_model,
            architecture=config.architecture,
        )

        encode_kwargs = self.renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        ).get_encode_kwargs()

        if cross_encoder:
            return self._cross_encoding_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
                score_template=chat_template,
            )

        if late_interaction:
            return self._late_interaction_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
            )

        return self._embedding_score(
            score_data_1,
            score_data_2,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
        )
1656

1657
1658
1659
1660
1661
1662
1663
1664
1665
    def start_profile(self, profile_prefix: str | None = None) -> None:
        """Start profiling with optional custom trace prefix.

        Args:
            profile_prefix: Optional prefix for the trace file names. If provided,
                           trace files will be named as "<prefix>_dp<X>_pp<Y>_tp<Z>".
                           If not provided, default naming will be used.
        """
        self.llm_engine.start_profile(profile_prefix)
1666
1667
1668
1669

    def stop_profile(self) -> None:
        """Stop profiling on the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

1670
1671
1672
1673
1674
1675
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Ask the engine to reset its prefix cache.

        Both flags are passed positionally to `llm_engine.reset_prefix_cache`,
        and its result is returned unchanged.

        Args:
            reset_running_requests: Forwarded to the engine as-is.
            reset_connector: Forwarded to the engine as-is.
        """
        engine = self.llm_engine
        result = engine.reset_prefix_cache(reset_running_requests, reset_connector)
        return result
1676

1677
1678
1679
1680
1681
1682
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep; it should process no requests meanwhile.

        The caller must make sure no requests are being processed during the
        sleep period, i.e. until `wake_up` is called.

        Args:
            level: The sleep level.
                - Level 0: Pause scheduling but continue accepting requests.
                           Requests are queued but not processed.
                - Level 1: Offload model weights to CPU, discard KV cache.
                           The content of kv cache is forgotten. Good for
                           sleeping and waking up the engine to run the same
                           model again. Please make sure there's enough CPU
                           memory to store the model weights.
                - Level 2: Discard all GPU memory (weights + KV cache).
                           Good for sleeping and waking up the engine to run
                           a different model or update the model, where
                           previous model weights are not needed. It reduces
                           CPU memory pressure.
        """
        # Levels above 0 discard the KV cache, so the prefix cache is reset
        # first; level 0 only pauses scheduling.
        discards_kv_cache = level > 0
        if discards_kv_cache:
            self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

1702
    def wake_up(self, tags: list[str] | None = None):
1703
        """
1704
1705
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1706

1707
        Args:
1708
1709
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1710
1711
1712
1713
                `("weights", "kv_cache", "scheduling")`. If None, all memory
                is reallocated. wake_up should be called with all tags
                (or None) before the engine is used again.
                Use tags=["scheduling"] to resume from level 0 sleep.
1714
1715
        """
        self.llm_engine.wake_up(tags)
1716

1717
1718
1719
1720
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus, exactly as returned by
            `llm_engine.get_metrics()`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1729
    def _params_to_seq(
        self,
        params: _P | Sequence[_P],
        num_requests: int,
    ) -> Sequence[_P]:
        """Expand `params` into one entry per request.

        Args:
            params: Either a single params object shared by every request, or
                a sequence with exactly one entry per request.
            num_requests: The number of prompts/conversations being submitted.

        Returns:
            `params` itself if it is already a sequence. Otherwise, a list
            that repeats the same object `num_requests` times.

        Raises:
            ValueError: If `params` is a sequence whose length differs from
                `num_requests`.
        """
        if isinstance(params, Sequence):
            if len(params) != num_requests:
                # Report the prompt count, not the params object itself.
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and params ({len(params)}) must be the same."
                )

            return params

        return [params] * num_requests

    def _lora_request_to_seq(
        self,
        lora_request: LoRARequest | None | Sequence[LoRARequest | None],
        num_requests: int,
    ) -> Sequence[LoRARequest | None]:
        """Expand `lora_request` into one (possibly `None`) entry per request.

        A non-sequence value, including `None`, is repeated `num_requests`
        times. A sequence is returned as-is after its length is checked.

        Raises:
            ValueError: If a sequence is given whose length differs from
                `num_requests`.
        """
        if not isinstance(lora_request, Sequence):
            return [lora_request] * num_requests

        if len(lora_request) == num_requests:
            return lora_request

        raise ValueError(
            f"The lengths of prompts ({num_requests}) "
            f"and lora_request ({len(lora_request)}) must be the same."
        )
1760

1761
1762
1763
1764
1765
    def _priority_to_seq(
        self,
        priority: list[int] | None,
        num_requests: int,
    ) -> Sequence[int]:
1766
1767
1768
1769
1770
1771
1772
        if priority is not None:
            if len(priority) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and priority ({len(priority)}) must be the same."
                )

1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
            return priority

        return [0] * num_requests

    def _run_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ):
        """Render completion prompts and run them through the engine.

        Args:
            prompts: A single prompt or a sequence of prompts.
            params: Sampling/pooling params, either shared by all prompts or
                one per prompt.
            use_tqdm: Progress-bar control, used both while rendering and
                while running the engine.
            lora_request: A LoRA request shared by all prompts, or one per
                prompt.
            priority: Optional per-prompt priorities; defaults to 0 for all.
            tokenization_kwargs: Base tokenization overrides. Each prompt is
                rendered with these merged with its params'
                `truncate_prompt_tokens`.

        Returns:
            Whatever `_render_and_run_requests` returns for the submitted
            prompts.
        """
        seq_prompts = prompt_to_seq(prompts)
        num_prompts = len(seq_prompts)

        seq_params = self._params_to_seq(params, num_prompts)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_prompts)
        seq_tok_kwargs = [
            merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
            )
            for param in seq_params
        ]
        # Use the normalized prompt count. For a single prompt, `len(prompts)`
        # would be the length of the string/dict, not the number of prompts.
        seq_priority = self._priority_to_seq(priority, num_prompts)

        return self._render_and_run_requests(
            prompts=(
                self._preprocess_cmpl_one(prompt, tok_kwargs)
                for prompt, tok_kwargs in zip(
                    maybe_tqdm(
                        seq_prompts,
                        use_tqdm=use_tqdm,
                        desc="Rendering prompts",
                    ),
                    seq_tok_kwargs,
                )
            ),
            params=seq_params,
            use_tqdm=use_tqdm,
            lora_requests=seq_lora_requests,
            tokenization_kwargs=tokenization_kwargs,
            priorities=seq_priority,
        )

    def _run_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ):
        """Render chat conversations and run them through the engine.

        A generator renders the conversations one per iteration, which is
        the form `_render_and_run_requests` prefers over a pre-built list.

        Args:
            messages: One conversation or a sequence of conversations.
            params: Sampling/pooling params, either shared by all
                conversations or one per conversation.
            use_tqdm: Progress-bar control, used both while rendering and
                while running the engine.
            lora_request: A LoRA request shared by all conversations, or one
                per conversation.
            chat_template, chat_template_content_format, chat_template_kwargs,
            add_generation_prompt, continue_final_message, tools,
            mm_processor_kwargs: Forwarded unchanged to
                `_preprocess_chat_one` for every conversation.
            tokenization_kwargs: Base tokenization overrides. Each
                conversation is rendered with these merged with its params'
                `truncate_prompt_tokens`.
        """
        conversations = conversation_to_seq(messages)
        count = len(conversations)

        per_conv_params = self._params_to_seq(params, count)
        per_conv_loras = self._lora_request_to_seq(lora_request, count)
        per_conv_tok_kwargs = [
            merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=p.truncate_prompt_tokens),
            )
            for p in per_conv_params
        ]

        progress = maybe_tqdm(
            conversations,
            use_tqdm=use_tqdm,
            desc="Rendering conversations",
        )
        rendered = (
            self._preprocess_chat_one(
                conversation,
                chat_template=chat_template,
                chat_template_content_format=chat_template_content_format,
                chat_template_kwargs=chat_template_kwargs,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                tokenization_kwargs=tok_kwargs,
                mm_processor_kwargs=mm_processor_kwargs,
            )
            for conversation, tok_kwargs in zip(progress, per_conv_tok_kwargs)
        )

        return self._render_and_run_requests(
            prompts=rendered,
            params=per_conv_params,
            lora_requests=per_conv_loras,
            use_tqdm=use_tqdm,
            tokenization_kwargs=tokenization_kwargs,
        )

1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
    def _render_and_run_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        priorities: Sequence[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ):
        """Add every prompt to the engine, then step it until all finish.

        `prompts` should be a lazy iterable (e.g. a generator). A list or
        tuple of already-rendered prompts is accepted, but it triggers a
        one-time efficiency warning.

        Returns:
            The result of `_run_engine(use_tqdm=use_tqdm)`.
        """
        already_rendered = isinstance(prompts, (list, tuple))
        if already_rendered:
            logger.warning_once(
                "Rendering all prompts before adding them to the engine "
                "is less efficient than performing both on the same prompt "
                "before processing the next prompt. You should instead pass "
                "a generator that renders one prompt per iteration, as that allows "
                "engine execution to begin for the first prompt while processing "
                "the next prompt."
            )

        self._render_and_add_requests(
            prompts=prompts,
            params=params,
            lora_requests=lora_requests,
            tokenization_kwargs=tokenization_kwargs,
            priorities=priorities,
        )
        return self._run_engine(use_tqdm=use_tqdm)

1908
    def _render_and_add_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        priorities: Sequence[int] | None = None,
    ) -> list[str]:
        """Add each prompt to the engine as `prompts` yields it.

        The i-th prompt is paired with `params[i]`, `lora_requests[i]` (or
        `None`) and `priorities[i]` (or 0).

        Returns:
            The request ids of the added requests, in prompt order.

        Raises:
            Any exception raised while iterating `prompts` (which may render
            them lazily) or while adding a request is re-raised unchanged.
            Before that, the requests already added in this call are
            aborted.
        """
        added_request_ids: list[str] = []

        try:
            for i, prompt in enumerate(prompts):
                request_id = self._add_request(
                    prompt,
                    params[i],
                    lora_request=None if lora_requests is None else lora_requests[i],
                    tokenization_kwargs=tokenization_kwargs,
                    priority=0 if priorities is None else priorities[i],
                )
                added_request_ids.append(request_id)
        except Exception:
            # Don't leave a partially submitted batch queued in the engine.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            # Bare `raise` re-raises the active exception as-is instead of
            # recording this line as a new raise point.
            raise

        return added_request_ids

1936
    def _add_request(
        self,
        prompt: ProcessorInputs,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        priority: int = 0,
    ) -> str:
        """Submit one rendered prompt to the engine.

        `SamplingParams` are switched (in place) to
        `RequestOutputKind.FINAL_ONLY`. A deprecated
        `params.truncate_prompt_tokens` triggers a `DeprecationWarning` and
        is merged into `tokenization_kwargs`.

        Returns:
            The value of `llm_engine.add_request(...)`, which callers use as
            the request id.
        """
        if isinstance(params, SamplingParams):
            # Only the final output is needed here.
            params.output_kind = RequestOutputKind.FINAL_ONLY

        request_id = str(next(self.request_counter))

        truncate = params.truncate_prompt_tokens
        if truncate is not None:
            cls_name = type(params).__name__
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{cls_name}` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate),
            )

        return self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
1973

1974
    def _run_engine(
        self,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If truthy, show a progress bar sized to the number of
                requests unfinished when this call starts. A callable is used
                as the bar factory instead of `tqdm`. For `RequestOutput`s,
                the bar postfix shows estimated input/output token
                throughput.

        Returns:
            All finished outputs, sorted by integer request id.
        """
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            bar_factory = use_tqdm if callable(use_tqdm) else tqdm
            pbar = bar_factory(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        finished: list[RequestOutput | PoolingRequestOutput] = []
        prompt_tokens_seen = 0
        generated_tokens_seen = 0

        while self.llm_engine.has_unfinished_requests():
            for output in self.llm_engine.step():
                if not output.finished:
                    continue
                finished.append(output)
                if not use_tqdm:
                    continue

                if not isinstance(output, RequestOutput):
                    pbar.update(1)
                else:
                    # Token throughput is only tracked for RequestOutput.
                    num_seqs = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    prompt_tokens_seen += len(output.prompt_token_ids) * num_seqs
                    in_spd = prompt_tokens_seen / pbar.format_dict["elapsed"]
                    generated_tokens_seen += sum(
                        len(seq.token_ids) for seq in output.outputs
                    )
                    out_spd = generated_tokens_seen / pbar.format_dict["elapsed"]
                    pbar.postfix = (
                        f"est. speed input: {in_spd:.2f} toks/s, "
                        f"output: {out_spd:.2f} toks/s"
                    )
                    pbar.update(num_seqs)

                if pbar.n == num_requests:
                    pbar.refresh()

        if use_tqdm:
            pbar.close()

        # Requests can finish before earlier-submitted ones, so restore
        # request-id order.
        return sorted(finished, key=lambda out: int(out.request_id))
2026

2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
    def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest | dict
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Args:
            request: Weight transfer initialization request with
                backend-specific info. It is either a
                `WeightTransferInitRequest` or a dict with an `"init_info"`
                key. The init info is passed to `llm_engine.collective_rpc`
                for the `"init_weight_transfer_engine"` method.
        """
        if isinstance(request, dict):
            init_info_dict = request["init_info"]
        else:
            init_info_dict = request.init_info

        self.llm_engine.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": init_info_dict}
        )

    def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
        """
        Update the weights of the model.

        Args:
            request: Weight update request with backend-specific update info.
                It is either a `WeightTransferUpdateRequest` or a dict with an
                `"update_info"` key. The update info is passed to
                `llm_engine.collective_rpc` for the `"update_weights"` method.
        """
        if isinstance(request, dict):
            update_info_dict = request["update_info"]
        else:
            update_info_dict = request.update_info

        self.llm_engine.collective_rpc(
            "update_weights", kwargs={"update_info": update_info_dict}
        )

2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model."""
        # Cache the result to avoid repeated collective_rpc calls
        if self._cached_repr is None:
            results = self.llm_engine.collective_rpc("get_model_inspection")
            # In distributed settings, we get results from all workers
            # Just return the first one (they should all be the same)
            if results:
                self._cached_repr = results[0]
            else:
                self._cached_repr = f"LLM(model={self.model_config.model!r})"
        return self._cached_repr