llm.py 82.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
7
from collections.abc import Callable, Iterable, Sequence
from typing import TYPE_CHECKING, Any
8

9
import cloudpickle
10
import torch.nn as nn
11
from pydantic import ValidationError
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar, overload
14

15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
22
    AttentionConfig,
23
    CompilationConfig,
24
    PoolerConfig,
25
    ProfilerConfig,
26
27
28
    StructuredOutputsConfig,
    is_init_field,
)
29
from vllm.config.compilation import CompilationMode
30
from vllm.config.model import (
31
32
    ConvertOption,
    HfOverrides,
33
    ModelDType,
34
    RunnerOption,
35
    TokenizerMode,
36
)
37
38
39
40
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
41
from vllm.engine.arg_utils import EngineArgs
42
43
44
45
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
46
from vllm.entrypoints.pooling.score.utils import (
47
    ScoreData,
48
49
50
    ScoreMultiModalParam,
    _cosine_similarity,
    compress_token_type_ids,
51
    compute_maxsim_score,
52
    get_score_prompt,
53
    validate_score_input,
54
)
55
from vllm.entrypoints.utils import log_non_default_args
56
from vllm.inputs.data import (
57
    DataPrompt,
58
    ProcessorInputs,
59
60
61
62
63
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
64
from vllm.logger import init_logger
65
from vllm.lora.request import LoRARequest
66
from vllm.model_executor.layers.quantization import QuantizationMethods
67
68
69
70
71
72
73
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
74
from vllm.platforms import current_platform
75
from vllm.pooling_params import PoolingParams
76
from vllm.renderers import ChatParams, merge_kwargs
77
78
79
80
81
from vllm.renderers.inputs.preprocess import (
    conversation_to_seq,
    parse_model_prompt,
    prompt_to_seq,
)
82
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
83
from vllm.tasks import PoolingTask
84
85
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
yhu422's avatar
yhu422 committed
86
from vllm.usage.usage_lib import UsageContext
87
from vllm.utils.counter import Counter
88
from vllm.utils.tqdm_utils import maybe_tqdm
89
from vllm.v1.engine.llm_engine import LLMEngine
90
from vllm.v1.sample.logits_processor import LogitsProcessor
91

92
93
94
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

95
96
logger = init_logger(__name__)

97
98
99
100
101
# Output type variable: bounded by, and defaulting to, either a generation
# output (`RequestOutput`) or a pooling output (`PoolingRequestOutput`).
# Used e.g. to type the `output_type` argument of `LLM.wait_for_completion`.
_O = TypeVar(
    "_O",
    bound=RequestOutput | PoolingRequestOutput,
    default=RequestOutput | PoolingRequestOutput,
)
102
# Per-request parameter type variable: sampling params, pooling params, or None.
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
103
104
# Generic result type (defaults to `Any`); used e.g. for the per-worker results
# returned by `LLM.collective_rpc` and `LLM.apply_model`.
_R = TypeVar("_R", default=Any)

105
106

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
107
108
109
110
111
112
113
114
115
116
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
117
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
118
119
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
120
121
122
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
123
124
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
125
126
127
128
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
129
        allowed_media_domains: If set, only media URLs that belong to this
130
            domain can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
131
132
133
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
134
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
135
136
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
137
        quantization: The method used to quantize the model weights. Currently,
138
            we support "awq", "gptq", and "fp8" (experimental).
139
140
141
142
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
143
144
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
145
146
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
147
148
149
150
151
152
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
153
154
155
156
157
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vllm can automatically infer the kv cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the kv cache memory size. kv_cache_memory_bytes
            allows more fine-grain control of how much memory gets used when
158
            compared with using gpu_memory_utilization. Note that
159
160
            kv_cache_memory_bytes (when not-None) ignores
            gpu_memory_utilization
161
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
162
163
164
165
166
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Noting that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
167
168
169
170
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
171
172
173
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
174
        enable_return_routed_experts: Whether to return routed experts.
175
176
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
177
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `hf auth login` (stored in `~/.cache/huggingface/token`).
180
181
182
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
183
184
185
186
187
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
188
189
        pooler_config: Initialize non-default pooling config for the pooling model,
            e.g., `PoolerConfig(seq_pooling_type="MEAN", use_activation=False)`.
190
        compilation_config: Either an integer or a dictionary. If it is an
191
            integer, it is used as the mode of compilation optimization. If it
192
            is a dictionary, it can specify the full compilation configuration.
193
194
195
196
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
197
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
198

199
200
    Note:
        This class is intended to be used for offline inference. For online
201
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
202
    """
203
204
205
206

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """Normalize the arguments, build `EngineArgs` and create the engine.

        The meaning of each argument is described in the class docstring.
        This constructor:

        1. Defaults `disable_log_stats` to `True` unless given in `kwargs`.
        2. Serializes a `worker_cls` passed as a class object with cloudpickle.
        3. Converts a dict `kv_transfer_config` into a `KVTransferConfig`.
        4. Converts dict/None/int config arguments into config instances.
        5. Rejects single-process data parallelism (see Raises).
        6. Creates the `LLMEngine` and caches frequently used engine members.

        Raises:
            ValueError: If a dict `kv_transfer_config` fails validation, or if
                `data_parallel_size > 1` is requested without the
                `"external_launcher"` executor backend on a non-TPU platform.
        """
        # Stats logging stays off unless the caller explicitly configures it.
        kwargs.setdefault("disable_log_stats", True)

        # If the worker_cls is not a qualified string name (i.e. it is an
        # actual class), serialize it with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        raw_config_dict = kwargs.get("kv_transfer_config")
        if isinstance(raw_config_dict, dict):
            from vllm.config.kv_transfer import KVTransferConfig

            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                # Consider re-raising a more specific vLLM error or ValueError
                # to provide better context to the user.
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        hf_overrides = {} if hf_overrides is None else hf_overrides

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance."""
            if isinstance(value, dict):
                # Only keys that are init fields of `cls` are forwarded.
                init_kwargs = {
                    key: val for key, val in value.items() if is_init_field(cls, key)
                }
                return cls(**init_kwargs)  # type: ignore[arg-type]
            return cls() if value is None else value

        # An integer compilation config is interpreted as the compilation mode.
        compilation_config_instance = (
            CompilationConfig(mode=CompilationMode(compilation_config))
            if isinstance(compilation_config, int)
            else _make_config(compilation_config, CompilationConfig)
        )
        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data parallel usage: it is only allowed with
        # the external launcher backend or on TPU.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )
        log_non_default_args(engine_args)

        engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.llm_engine = engine
        self.engine_class = type(engine)

        self.request_counter = Counter()
        # Filled lazily by `get_default_sampling_params`.
        self.default_sampling_params: dict[str, Any] | None = None

        self.supported_tasks = engine.get_supported_tasks()
        logger.info("Supported tasks: %s", self.supported_tasks)

        # Shortcuts to members owned by the engine.
        self.model_config = engine.model_config
        self.renderer = engine.renderer
        self.io_processor = engine.io_processor
        self.input_processor = engine.input_processor

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

371
    def get_tokenizer(self) -> TokenizerLike:
372
        return self.llm_engine.get_tokenizer()
373

374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
    def get_world_size(self, include_dp: bool = True) -> int:
        """Get the world size from the parallel config.

        Args:
            include_dp: If True (default), returns the world size including
                data parallelism (TP * PP * DP). If False, returns the world
                size without data parallelism (TP * PP).

        Returns:
            The world size (tensor_parallel_size * pipeline_parallel_size),
            optionally multiplied by data_parallel_size if include_dp is True.
        """
        parallel_config = self.llm_engine.vllm_config.parallel_config
        if include_dp:
            return parallel_config.world_size_across_dp
        return parallel_config.world_size

391
    def reset_mm_cache(self) -> None:
392
        self.renderer.clear_mm_cache()
393
394
        self.llm_engine.reset_mm_cache()

395
    def get_default_sampling_params(self) -> SamplingParams:
        """Return sampling parameters built from the model's default overrides.

        The overrides are read once from
        `model_config.get_diff_sampling_param()` and cached on
        `self.default_sampling_params`. A new `SamplingParams` object is
        created on every call.

        Returns:
            `SamplingParams.from_optional(**overrides)` when the cached
            overrides are non-empty, otherwise a plain `SamplingParams()`.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            # First call: fetch and remember the model-specific overrides.
            overrides = self.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides

        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

402
403
    def generate(
        self,
404
405
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
406
        *,
407
        use_tqdm: bool | Callable[..., tqdm] = True,
408
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
409
        priority: list[int] | None = None,
410
        tokenization_kwargs: dict[str, Any] | None = None,
411
    ) -> list[RequestOutput]:
Woosuk Kwon's avatar
Woosuk Kwon committed
412
413
        """Generates the completions for the input prompts.

414
        This class automatically batches the given prompts, considering
Woosuk Kwon's avatar
Woosuk Kwon committed
415
416
417
418
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
419
            prompts: The prompts to the LLM. You may pass a sequence of prompts
420
                for batch inference. See [PromptType][vllm.inputs.PromptType]
421
                for more details about the format of each prompt.
Woosuk Kwon's avatar
Woosuk Kwon committed
422
            sampling_params: The sampling parameters for text generation. If
nunjunj's avatar
nunjunj committed
423
424
425
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
426
                prompts and it is paired one by one with the prompt.
427
428
429
430
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
431
            lora_request: LoRA request to use for generation, if any.
432
433
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
434
435
436
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the prompt
                at the same index.
437
            tokenization_kwargs: Overrides for `tokenizer.encode`.
Woosuk Kwon's avatar
Woosuk Kwon committed
438
439

        Returns:
440
            A list of `RequestOutput` objects containing the
441
442
            generated completions in the same order as the input prompts.
        """
443
        model_config = self.model_config
444
445
        runner_type = model_config.runner_type
        if runner_type != "generate":
446
447
448
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
449
450
                "generative model."
            )
451

452
        if sampling_params is None:
453
            sampling_params = self.get_default_sampling_params()
454

455
        return self._run_completion(
456
            prompts=prompts,
457
            params=sampling_params,
458
            output_type=RequestOutput,
459
            use_tqdm=use_tqdm,
460
            lora_request=lora_request,
461
            tokenization_kwargs=tokenization_kwargs,
462
463
            priority=priority,
        )
464

465
466
467
468
    def enqueue(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
469
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
        priority: list[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Enqueue prompts for generation without waiting for completion.

        This method adds requests to the engine queue but does not start
        processing them. Use wait_for_completion() to process the queued
        requests and get results.

        Args:
            prompts: The prompts to the LLM. See generate() for details.
            sampling_params: The sampling parameters for text generation.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
            use_tqdm: If True, shows a tqdm progress bar while adding requests.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of request IDs for the enqueued requests.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError("LLM.enqueue() is only supported for generative models.")

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        # Use the same preprocessing as _run_completion
        seq_prompts = prompt_to_seq(prompts)
        seq_params = self._params_to_seq(sampling_params, len(seq_prompts))
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
        seq_lora_requests = self._lora_request_to_seq(lora_request, len(seq_prompts))
        seq_tok_kwargs = [
            merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
            )
            for param in seq_params
        ]
        seq_priority = self._priority_to_seq(priority, len(prompts))

        request_ids = self._render_and_add_requests(
            prompts=(
                self._preprocess_cmpl_one(prompt, tok_kwargs)
                for prompt, tok_kwargs in zip(
                    maybe_tqdm(
                        seq_prompts,
                        use_tqdm=use_tqdm,
                        desc="Rendering prompts",
520
                    ),
521
                    seq_tok_kwargs,
522
523
                )
            ),
524
525
526
            params=seq_params,
            lora_requests=seq_lora_requests,
            priorities=seq_priority,
527
528
529
530
        )

        return request_ids

531
    @overload
532
533
    def wait_for_completion(
        self,
534
        *,
535
        use_tqdm: bool | Callable[..., tqdm] = True,
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
    ) -> list[RequestOutput | PoolingRequestOutput]: ...

    @overload
    def wait_for_completion(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]: ...

    def wait_for_completion(
        self,
        output_type: type[Any] | tuple[type[Any], ...] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[Any]:
552
553
554
555
556
557
        """Wait for all enqueued requests to complete and return results.

        This method processes all requests currently in the engine queue
        and returns their outputs. Use after enqueue() to get results.

        Args:
558
            output_type: The expected output type, defaults to RequestOutput.
559
560
561
            use_tqdm: If True, shows a tqdm progress bar.

        Returns:
562
            A list of output objects for all completed requests.
563
        """
564
565
566
567
        if output_type is None:
            output_type = (RequestOutput, PoolingRequestOutput)

        return self._run_engine(output_type, use_tqdm=use_tqdm)
568

Cyrus Leung's avatar
Cyrus Leung committed
569
    def _resolve_mm_lora(
570
        self,
571
        prompt: ProcessorInputs,
572
        lora_request: LoRARequest | None,
Cyrus Leung's avatar
Cyrus Leung committed
573
574
575
576
577
578
579
    ) -> LoRARequest | None:
        if prompt["type"] != "multimodal":
            return lora_request

        lora_config = self.llm_engine.vllm_config.lora_config
        default_mm_loras = None if lora_config is None else lora_config.default_mm_loras
        if not default_mm_loras:
580
581
            return lora_request

582
583
        prompt_modalities = prompt["mm_placeholders"].keys()
        intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
584
585
        if not intersection:
            return lora_request
Cyrus Leung's avatar
Cyrus Leung committed
586

587
588
589
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
Cyrus Leung's avatar
Cyrus Leung committed
590
591
592
593
                "Multiple modality specific loras were registered and would be "
                "used by a single prompt consuming several modalities; "
                "currently we only support one lora per request; as such, "
                "lora(s) registered with modalities: %s will be skipped",
594
595
                intersection,
            )
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # If we have a collision, warn if there is a collision,
        # but always send the explicitly provided request.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
611
612
                    "lora_request as we only apply one LoRARequest per prompt"
                )
613
614
615
616
617
618
619
620
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

621
622
    def collective_rpc(
        self,
623
624
        method: str | Callable[..., _R],
        timeout: float | None = None,
625
        args: tuple = (),
626
        kwargs: dict[str, Any] | None = None,
627
    ) -> list[_R]:
628
629
630
631
632
633
634
635
636
637
638
        """
        Execute an RPC call on all workers.

        Args:
            method: Name of the worker method to execute, or a callable that
                is serialized and sent to all workers to execute.

                If the method is a callable, it should accept an additional
                `self` argument, in addition to the arguments passed in `args`
                and `kwargs`. The `self` argument will be the worker object.
            timeout: Maximum time in seconds to wait for execution. Raises a
639
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
640
641
642
643
644
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.
645

646
647
648
        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
649
        """
650
651

        return self.llm_engine.collective_rpc(method, timeout, args, kwargs)
652
653

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
654
        """
655
656
        Run a function directly on the model inside each worker,
        returning the result for each of them.
657
658
659
660
661
662

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
663
        """
664
        return self.llm_engine.apply_model(func)
665

666
667
    def beam_search(
        self,
        prompts: list[TokensPrompt | TextPrompt],
        params: BeamSearchParams,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        use_tqdm: bool = False,
        concurrency_limit: int | None = None,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        The search advances one token per step: every live beam is sent to
        the engine with `max_tokens=1`, the returned top log-probabilities
        are expanded into candidate beams, and the best `beam_width`
        candidates per prompt survive to the next step.

        Args:
            prompts: A list of prompts. Each prompt is a `TextPrompt` or a
                `TokensPrompt`.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            use_tqdm: Whether to use tqdm to display the progress bar.
                Disabled (with a warning) when `concurrency_limit` is set.
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
                Must be a positive integer when given.

        Returns:
            One `BeamSearchOutput` per rendered prompt, each holding at most
            `params.beam_width` sequences with their decoded text. An empty
            prompt list yields an empty list.

        Raises:
            ValueError: If `concurrency_limit` is given but is not positive.
            NotImplementedError: If a rendered prompt is an embeddings
                prompt or an encoder-decoder prompt.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency penalty, and stopping criteria, etc.?
        if concurrency_limit is not None and concurrency_limit < 1:
            # A zero step would crash inside `range()` below and a negative
            # one would silently skip generation, so reject both up front.
            raise ValueError(
                f"concurrency_limit must be a positive integer or None, "
                f"got {concurrency_limit}."
            )

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        engine_prompts = self._preprocess_cmpl(prompts)
        lora_requests = self._lora_request_to_seq(lora_request, len(engine_prompts))

        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar."
            )
            use_tqdm = False

        if concurrency_limit is None:
            # `range()` rejects a step of 0, so keep the batch size at least 1
            # even when there are no prompts at all.
            concurrency_limit = max(len(engine_prompts), 1)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        sampling_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
        )
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, engine_prompts):
            if prompt["type"] == "embeds":
                raise NotImplementedError(
                    "Embedding prompt not supported for beam search"
                )
            if prompt["type"] == "enc_dec":
                raise NotImplementedError(
                    "Encoder-decoder prompt not supported for beam search"
                )

            instances.append(
                BeamSearchInstance(
                    prompt,
                    lora_request=lora_req,
                    logprobs=None,
                ),
            )

        for prompt_start in range(0, len(instances), concurrency_limit):
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress."
                )

            for _ in token_iter:
                # Flatten the live beams of every instance into one request
                # batch (linear time, unlike `sum(lists, [])`).
                all_beams: list[BeamSearchSequence] = list(
                    itertools.chain.from_iterable(
                        instance.beams for instance in instances_batch
                    )
                )
                if not all_beams:
                    break

                # Remember which slice of the flat batch belongs to which
                # instance so the results can be routed back.
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch
                    )
                )
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:])
                )

                # only runs for one step
                # we don't need to use tqdm here
                output = self._render_and_run_requests(
                    prompts=(beam.get_prompt() for beam in all_beams),
                    params=self._params_to_seq(sampling_params, len(all_beams)),
                    output_type=RequestOutput,
                    lora_requests=[beam.lora_request for beam in all_beams],
                    use_tqdm=False,
                )

                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    current_beam.orig_prompt,
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs + [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                )

                                if token_id == eos_token_id and not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)

                    # Keep only the best `beam_width` unfinished candidates.
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
                    instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive when the token budget ran out compete with
            # the beams that finished on EOS.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)

            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

827
    def _preprocess_cmpl(
        self,
        prompts: Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Render completion-style prompts, i.e. prompts coming from LLM APIs
        other than [LLM.chat][], into a format that can be passed to
        `_add_request`.

        Each prompt is parsed with `parse_model_prompt` and the parsed
        prompts are rendered with the renderer's default completion
        tokenization parameters, overridden by `tokenization_kwargs`.

        Refer to [LLM.generate][] for a complete description of the arguments.

        Returns:
            A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
        """
        renderer = self.renderer
        model_config = self.model_config

        parsed_prompts = [parse_model_prompt(model_config, item) for item in prompts]

        # `None` means "no overrides".
        overrides = tokenization_kwargs or {}
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(**overrides)

        return renderer.render_cmpl(parsed_prompts, tok_params)
852

853
854
855
856
857
858
859
860
    def _preprocess_cmpl_one(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """
        Single-prompt variant of `_preprocess_cmpl`.

        Args:
            prompt: The prompt to render.
            tokenization_kwargs: Forwarded to `_preprocess_cmpl` unchanged.

        Returns:
            The single `ProcessorInputs` rendered for `prompt`.
        """
        rendered = self._preprocess_cmpl([prompt], tokenization_kwargs)
        # One prompt in, one prompt out: the unpacking fails with a
        # `ValueError` if the renderer returned any other number of items.
        [engine_prompt] = rendered
        return engine_prompt

861
862
    def _preprocess_chat(
        self,
        conversations: Sequence[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[ProcessorInputs]:
        """
        Render a list of conversations into prompts so that they can then
        be used as input for other LLM APIs.

        Refer to [LLM.chat][] for a complete description of the arguments.

        Returns:
            A list of `ProcessorInputs` objects ready to be passed into LLMEngine.
        """
        renderer = self.renderer

        # Template kwargs derived from this method's explicit arguments.
        # `tokenize` is only true when the renderer's tokenizer is a
        # `MistralTokenizer`.
        explicit_template_kwargs = dict(
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
        )
        # NOTE(review): which side wins on a key clash is decided by
        # `merge_kwargs` - confirm against its implementation.
        merged_template_kwargs = merge_kwargs(
            chat_template_kwargs,
            explicit_template_kwargs,
        )
        chat_params = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=merged_template_kwargs,
        )

        # `None` means "no overrides".
        overrides = tokenization_kwargs or {}
        tok_params = renderer.default_chat_tok_params.with_kwargs(**overrides)

        # `render_chat` returns a pair; only the second element is needed here.
        _, engine_prompts = renderer.render_chat(
            conversations,
            chat_params,
            tok_params,
            prompt_extras={"mm_processor_kwargs": mm_processor_kwargs},
        )
        return engine_prompts
909

910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
    def _preprocess_chat_one(
        self,
        conversation: list[ChatCompletionMessageParam],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> ProcessorInputs:
        """
        Single-conversation variant of `_preprocess_chat`.

        All arguments other than `conversation` are forwarded to
        `_preprocess_chat` unchanged.

        Returns:
            The single `ProcessorInputs` rendered for `conversation`.
        """
        rendered = self._preprocess_chat(
            [conversation],
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        # One conversation in, one prompt out: the unpacking fails with a
        # `ValueError` if any other number of items came back.
        [engine_prompt] = rendered
        return engine_prompt

936
937
    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the [generate][vllm.LLM.generate] method to generate
        the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A sequence of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Tool definitions, given as dictionaries. They are passed
                through to the chat rendering step unchanged.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            mm_processor_kwargs: Overrides for `processor.__call__`.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.

        Raises:
            ValueError: If the model is not running as a generative model.
        """
        # Chat only makes sense for models served with the "generate" runner.
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.chat() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        params = (
            self.get_default_sampling_params()
            if sampling_params is None
            else sampling_params
        )

        return self._run_chat(
            messages=messages,
            params=params,
            output_type=RequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

1030
1031
    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
                A dictionary containing a `data` key is instead handed to the
                installed IOProcessor plugin, which produces the actual model
                prompts and post-processes the outputs.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Deprecated. When given, a
                `DeprecationWarning` is emitted and the value is merged into
                `tokenization_kwargs`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to use. Required.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            On the IOProcessor path, a single-element list wrapping the
            plugin's post-processed output.

        Raises:
            ValueError: If `pooling_task` is missing, the model is not a
                pooling model, the IOProcessor path is requested without a
                plugin or with `data` set to None, or a pooling parameter
                already carries a different task than `pooling_task`.
        """
        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        if self.model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if truncate_prompt_tokens is not None:
            # `stacklevel=2` attributes the warning to the caller of `encode`.
            warnings.warn(
                "The `truncate_prompt_tokens` parameter in `LLM.encode()` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                {"truncate_prompt_tokens": truncate_prompt_tokens},
            )

        if use_io_processor := (isinstance(prompts, dict) and "data" in prompts):
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            prompt_data = prompts.get("data")
            if prompt_data is None:
                raise ValueError(
                    "The 'data' field of the prompt is expected to contain "
                    "the prompt data and it cannot be None. "
                    "Refer to the documentation of the IOProcessor "
                    "in use for more details."
                )
            validated_prompt = self.io_processor.parse_data(prompt_data)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)
            prompts_seq = prompt_to_seq(prompts)

            # Let the plugin adjust every pooling parameter, then default
            # any still-unset task to "plugin".
            requested_params = self._params_to_seq(pooling_params, len(prompts_seq))
            params_seq: Sequence[PoolingParams] = [
                self.io_processor.merge_pooling_params(requested)
                for requested in requested_params
            ]
            for merged in params_seq:
                if merged.task is None:
                    merged.task = "plugin"
        else:
            if pooling_params is None:
                # Use default pooling params.
                pooling_params = PoolingParams()

            prompts_seq = prompt_to_seq(prompts)
            params_seq = self._params_to_seq(pooling_params, len(prompts_seq))

            # Fill in the task where it is unset; refuse to silently replace
            # a task the caller already chose.
            for param in params_seq:
                if param.task is None:
                    param.task = pooling_task
                    continue
                if param.task != pooling_task:
                    raise ValueError(
                        f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                    )

        outputs = self._run_completion(
            prompts=prompts_seq,
            params=params_seq,
            output_type=PoolingRequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        if not use_io_processor:
            return outputs

        # get the post-processed model outputs
        assert self.io_processor is not None
        processed_outputs = self.io_processor.post_process(outputs)

        wrapped_output = PoolingRequestOutput[Any](
            request_id="",
            outputs=processed_outputs,
            num_cached_tokens=getattr(processed_outputs, "num_cached_tokens", 0),
            prompt_token_ids=[],
            finished=True,
        )
        return [wrapped_output]
1185

1186
1187
    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: When given, merged into
                `tokenization_kwargs` under the same name.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        if truncate_prompt_tokens is not None:
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                {"truncate_prompt_tokens": truncate_prompt_tokens},
            )

        pooled = self.encode(
            prompts,
            pooling_params=pooling_params,
            pooling_task="embed",
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        return [EmbeddingRequestOutput.from_base(entry) for entry in pooled]

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        pooled = self.encode(
            prompts,
            pooling_params=pooling_params,
            pooling_task="classify",
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(entry) for entry in pooled]

1293
1294
    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        This is `encode` with the pooling task fixed to "token_classify".

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Forwarded to `encode` unchanged, where
                passing a value is deprecated in favour of
                `tokenization_kwargs`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        reward_outputs = self.encode(
            prompts,
            pooling_params=pooling_params,
            pooling_task="token_classify",
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )
        return reward_outputs

1335
1336
    def _embedding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Score text pairs from their embeddings.

        Both inputs are embedded in a single `encode` call (pooling task
        "embed"), the outputs are split back into the two sides, and the
        pairs are scored with `_cosine_similarity`. A single item in
        `data_1` is paired with every item of `data_2`.

        Args:
            data_1: First side of the pairs; every item must be a string.
            data_2: Second side of the pairs; every item must be a string.
            use_tqdm: Forwarded to `encode`.
            pooling_params: Forwarded to `encode`.
            lora_request: Forwarded to `encode`.
            tokenization_kwargs: Forwarded to `encode`.

        Returns:
            One `ScoringRequestOutput` per score returned by
            `_cosine_similarity`.

        Raises:
            NotImplementedError: If any item is not a string.
        """
        tokenizer = self.get_tokenizer()

        # Only plain text is supported; stop at the first non-string item.
        input_texts: list[str] = []
        for entry in [*data_1, *data_2]:
            if not isinstance(entry, str):
                raise NotImplementedError(
                    "Embedding scores currently do not support multimodal input."
                )
            input_texts.append(entry)

        encoded_output = self.encode(
            input_texts,
            pooling_params=pooling_params,
            pooling_task="embed",
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        # The first `len(data_1)` outputs belong to `data_1`, the rest to
        # `data_2`.
        boundary = len(data_1)
        embeds_1 = encoded_output[:boundary]
        embeds_2 = encoded_output[boundary:]

        # One-to-many: repeat the single left-hand embedding for each
        # right-hand one.
        if len(embeds_1) == 1:
            embeds_1 = embeds_1 * len(embeds_2)

        scores = _cosine_similarity(
            tokenizer=tokenizer,
            embed_1=embeds_1,
            embed_2=embeds_2,
        )

        return [ScoringRequestOutput.from_base(score) for score in scores]
1377

1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
    def _late_interaction_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Late interaction scoring (ColBERT MaxSim).

        Queries (`data_1`) and documents (`data_2`) are encoded together with
        `pooling_task="token_embed"`; each query/document pair is then scored
        with `compute_maxsim_score`.

        Args:
            data_1: Query texts. A single query is repeated so that it pairs
                with every document.
            data_2: Document texts.
            use_tqdm: Forwarded to `self.encode`.
            pooling_params: Forwarded to `self.encode`.
            lora_request: Forwarded to `self.encode`.
            tokenization_kwargs: Forwarded to `self.encode`.

        Returns:
            One `ScoringRequestOutput` per query/document pair.

        Raises:
            NotImplementedError: If any item is not a plain string.
        """
        from vllm.outputs import PoolingOutput

        tokenizer = self.get_tokenizer()

        unsupported_msg = (
            "Late interaction scores currently do not support multimodal input."
        )

        def _require_text(items: list[ScoreData]) -> list[str]:
            # Only plain strings are accepted; fail on the first other item.
            for item in items:
                if not isinstance(item, str):
                    raise NotImplementedError(unsupported_msg)
            return list(items)

        queries = _require_text(data_1)
        documents = _require_text(data_2)

        token_embeddings = self.encode(
            queries + documents,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        num_queries = len(queries)
        query_outputs = token_embeddings[:num_queries]
        doc_outputs = token_embeddings[num_queries:]

        # 1 -> N: reuse the single query output for every document.
        if len(query_outputs) == 1:
            query_outputs = query_outputs * len(doc_outputs)

        # A pad token (when the tokenizer defines one) separates the query
        # token ids from the document token ids in the combined output.
        pad_token_id = tokenizer.pad_token_id
        separator: list[int] = [] if pad_token_id is None else [pad_token_id]

        # Compute MaxSim scores
        pooled: list[PoolingRequestOutput] = []
        for query_out, doc_out in zip(query_outputs, doc_outputs):
            maxsim_score = compute_maxsim_score(
                query_out.outputs.data,
                doc_out.outputs.data,
            )
            combined_ids = (
                query_out.prompt_token_ids + separator + doc_out.prompt_token_ids
            )
            cached = query_out.num_cached_tokens + doc_out.num_cached_tokens

            pooled.append(
                PoolingRequestOutput(
                    request_id=f"{query_out.request_id}_{doc_out.request_id}",
                    outputs=PoolingOutput(data=maxsim_score),
                    prompt_token_ids=combined_ids,
                    num_cached_tokens=cached,
                    finished=True,
                )
            )

        return [ScoringRequestOutput.from_base(result) for result in pooled]
1457

1458
1459
    def _cross_encoding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
        score_template: str | None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by running each pair as one prompt through the model.

        Each `(data_1[i], data_2[i])` pair is turned into an engine prompt by
        `get_score_prompt` and submitted through `_run_completion` with
        pooling task `"score"`.

        Args:
            data_1: First side of each pair. A single item is repeated so
                that it pairs with every item of `data_2`.
            data_2: Second side of each pair.
            use_tqdm: Forwarded to `_run_completion`.
            pooling_params: Pooling parameters. When omitted, parameters with
                `task="score"` are created; when given with `task` unset,
                the task is set to `"score"` on the passed object.
            lora_request: Forwarded to `_run_completion`.
            tokenization_kwargs: Forwarded to `get_score_prompt`.
            score_template: Forwarded to `get_score_prompt`.

        Returns:
            One `ScoringRequestOutput` per pair.

        Raises:
            ValueError: If the tokenizer is a `MistralTokenizer`.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        # 1 -> N: repeat the single data_1 item for every data_2 item.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        elif pooling_params.task is None:
            pooling_params.task = "score"

        per_prompt_params: list[PoolingParams] = []
        engine_prompts: list[PromptType] = []

        for query, document in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=query,
                data_2=document,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            # token_type_ids are removed from the prompt and, when present,
            # passed in compressed form through a per-request params clone.
            token_type_ids = engine_prompt.pop("token_type_ids", None)
            if token_type_ids:
                request_params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                request_params.extra_kwargs = {
                    "compressed_token_type_ids": compressed
                }
            else:
                request_params = pooling_params

            per_prompt_params.append(request_params)
            engine_prompts.append(engine_prompt)

        pooled = self._run_completion(
            prompts=engine_prompts,
            params=per_prompt_params,
            output_type=PoolingRequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        return [ScoringRequestOutput.from_base(result) for result in pooled]
1518

1519
1520
    def score(
        self,
        data_1: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        data_2: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        Inputs may be paired `1 -> 1`, `1 -> N` or `N -> N`. In the `1 -> N`
        case the single `data_1` input is replicated `N` times so that it is
        paired with every `data_2` input.

        Depending on the model, the pairs are scored by a cross encoder, by
        late interaction, or by comparing embeddings. This class
        automatically batches the prompts, considering the memory
        constraint. For the best performance, put all of your inputs into a
        single list and pass it to this method.

        Both text and multi-modal data (images, etc.) are supported when used
        with appropriate multi-modal models. For multi-modal inputs, ensure
        the prompt structure matches the model's expected input format.

        Args:
            data_1: A single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length
                as the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: The chat template to use for the scoring. If None,
                we use the model's default chat template. Only accepted for
                cross-encoder models.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a pooling model, does not
                support scoring, is a cross encoder with `num_labels != 1`,
                or `chat_template` is given for a non-cross-encoder model.
        """
        model_config = self.model_config

        if model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks

        # Late interaction models (e.g., ColBERT) use token_embed for scoring
        is_late_interaction = model_config.is_late_interaction
        if not (
            is_late_interaction
            or any(task in supported_tasks for task in ("embed", "classify"))
        ):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        is_cross_encoder = model_config.is_cross_encoder
        if is_cross_encoder:
            num_labels = getattr(model_config.hf_config, "num_labels", 0)
            if num_labels != 1:
                raise ValueError("Score API is only enabled for num_labels == 1.")
        elif chat_template is not None:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        score_data_1, score_data_2 = validate_score_input(
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            is_multimodal_model=model_config.is_multimodal_model,
            architecture=model_config.architecture,
        )

        # Merge caller overrides into the renderer's default tokenization
        # parameters and convert them to tokenizer encode kwargs.
        overrides = tokenization_kwargs or {}
        tok_params = self.renderer.default_cmpl_tok_params.with_kwargs(**overrides)
        encode_kwargs = tok_params.get_encode_kwargs()

        if is_cross_encoder:
            return self._cross_encoding_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
                score_template=chat_template,
            )

        if is_late_interaction:
            return self._late_interaction_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
            )

        return self._embedding_score(
            score_data_1,
            score_data_2,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
        )
1652

1653
1654
1655
1656
1657
1658
1659
1660
1661
    def start_profile(self, profile_prefix: str | None = None) -> None:
        """Ask the engine to start profiling.

        Args:
            profile_prefix: Optional prefix for the trace file names. When
                given, trace files are named "<prefix>_dp<X>_pp<Y>_tp<Z>";
                otherwise the default naming is used.
        """
        engine = self.llm_engine
        engine.start_profile(profile_prefix)
1662
1663
1664
1665

    def stop_profile(self) -> None:
        """Ask the engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1666
1667
1668
1669
1670
1671
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Forward a prefix-cache reset to the engine.

        Args:
            reset_running_requests: Passed through to the engine unchanged.
            reset_connector: Passed through to the engine unchanged.

        Returns:
            Whatever the engine's `reset_prefix_cache` returns.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(reset_running_requests, reset_connector)
1672

1673
1674
1675
1676
1677
1678
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        For any level above 0 the prefix cache is reset before the engine is
        put to sleep.

        Args:
            level: The sleep level.
                - Level 0: Pause scheduling but continue accepting requests.
                           Requests are queued but not processed.
                - Level 1: Offload model weights to CPU, discard KV cache.
                           The content of kv cache is forgotten. Good for
                           sleeping and waking up the engine to run the same
                           model again. Please make sure there's enough CPU
                           memory to store the model weights.
                - Level 2: Discard all GPU memory (weights + KV cache).
                           Good for sleeping and waking up the engine to run
                           a different model or update the model, where
                           previous model weights are not needed. It reduces
                           CPU memory pressure.
        """
        discards_kv_cache = level > 0
        if discards_kv_cache:
            self.reset_prefix_cache()

        self.llm_engine.sleep(level=level)

1698
    def wake_up(self, tags: list[str] | None = None):
        """
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.

        Args:
            tags: An optional list of tags selecting which memory
                allocations to restore. Values must be in
                `("weights", "kv_cache", "scheduling")`. If None, all memory
                is reallocated. wake_up should be called with all tags
                (or None) before the engine is used again.
                Use tags=["scheduling"] to resume from level 0 sleep.
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1712

1713
1714
1715
1716
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects, as returned by the engine's
            `get_metrics`, capturing the current state of all aggregated
            metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine.
        """
        snapshot = self.llm_engine.get_metrics()
        return snapshot

1725
    def _params_to_seq(
1726
        self,
1727
        params: _P | Sequence[_P],
1728
        num_requests: int,
1729
    ) -> Sequence[_P]:
1730
1731
1732
1733
        if isinstance(params, Sequence):
            if len(params) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({params}) "
1734
                    f"and params ({len(params)}) must be the same."
1735
1736
                )

1737
            return params
1738

1739
1740
1741
1742
1743
1744
1745
        return [params] * num_requests

    def _lora_request_to_seq(
        self,
        lora_request: LoRARequest | None | Sequence[LoRARequest | None],
        num_requests: int,
    ) -> Sequence[LoRARequest | None]:
1746
1747
1748
1749
1750
1751
1752
        if isinstance(lora_request, Sequence):
            if len(lora_request) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and lora_request ({len(lora_request)}) must be the same."
                )

1753
1754
1755
            return lora_request

        return [lora_request] * num_requests
1756

1757
1758
1759
1760
1761
    def _priority_to_seq(
        self,
        priority: list[int] | None,
        num_requests: int,
    ) -> Sequence[int]:
1762
1763
1764
1765
1766
1767
1768
        if priority is not None:
            if len(priority) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and priority ({len(priority)}) must be the same."
                )

1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
            return priority

        return [0] * num_requests

    def _run_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ):
        """Render completion prompts lazily and run them through the engine.

        Args:
            prompts: A single prompt or a sequence of prompts.
            params: One params object shared by all prompts, or one per
                prompt.
            output_type: Expected type of the engine outputs.
            use_tqdm: Progress-bar setting, used both while rendering prompts
                and while running the engine.
            lora_request: One LoRA request shared by all prompts, or one per
                prompt.
            priority: Optional per-prompt priorities.
            tokenization_kwargs: Base tokenization kwargs; each prompt's
                `truncate_prompt_tokens` (taken from its params) is merged
                in.

        Returns:
            The result of `_render_and_run_requests`.

        Raises:
            ValueError: If a per-prompt sequence (`params`, `lora_request`
                or `priority`) does not match the number of prompts.
        """
        seq_prompts = prompt_to_seq(prompts)
        num_prompts = len(seq_prompts)

        seq_params = self._params_to_seq(params, num_prompts)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_prompts)
        seq_tok_kwargs = [
            merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
            )
            for param in seq_params
        ]
        # Validate against the normalized prompt count. `len(prompts)` would
        # be the character/key count when a single un-wrapped prompt (a str
        # or a prompt dict) is passed.
        seq_priority = self._priority_to_seq(priority, num_prompts)

        return self._render_and_run_requests(
            prompts=(
                self._preprocess_cmpl_one(prompt, tok_kwargs)
                for prompt, tok_kwargs in zip(
                    maybe_tqdm(
                        seq_prompts,
                        use_tqdm=use_tqdm,
                        desc="Rendering prompts",
                    ),
                    seq_tok_kwargs,
                )
            ),
            params=seq_params,
            output_type=output_type,
            use_tqdm=use_tqdm,
            lora_requests=seq_lora_requests,
            priorities=seq_priority,
        )

    def _run_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ):
        """Render chat conversations lazily and run them through the engine.

        Args:
            messages: A single conversation or a sequence of conversations.
            params: One params object shared by all conversations, or one
                per conversation.
            output_type: Expected type of the engine outputs.
            use_tqdm: Progress-bar setting, used both while rendering
                conversations and while running the engine.
            lora_request: One LoRA request shared by all conversations, or
                one per conversation.
            chat_template: Forwarded to `_preprocess_chat_one`.
            chat_template_content_format: Forwarded to
                `_preprocess_chat_one`.
            add_generation_prompt: Forwarded to `_preprocess_chat_one`.
            continue_final_message: Forwarded to `_preprocess_chat_one`.
            tools: Forwarded to `_preprocess_chat_one`.
            chat_template_kwargs: Forwarded to `_preprocess_chat_one`.
            tokenization_kwargs: Base tokenization kwargs; each
                conversation's `truncate_prompt_tokens` (taken from its
                params) is merged in.
            mm_processor_kwargs: Forwarded to `_preprocess_chat_one`.

        Returns:
            The result of `_render_and_run_requests`.
        """
        seq_convs = conversation_to_seq(messages)
        num_convs = len(seq_convs)

        seq_params = self._params_to_seq(params, num_convs)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_convs)

        seq_tok_kwargs = []
        for param in seq_params:
            truncation = dict(truncate_prompt_tokens=param.truncate_prompt_tokens)
            seq_tok_kwargs.append(merge_kwargs(tokenization_kwargs, truncation))

        # A generator, so each conversation is rendered only when the engine
        # side asks for it.
        rendered_prompts = (
            self._preprocess_chat_one(
                conversation,
                chat_template=chat_template,
                chat_template_content_format=chat_template_content_format,
                chat_template_kwargs=chat_template_kwargs,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                tokenization_kwargs=tok_kwargs,
                mm_processor_kwargs=mm_processor_kwargs,
            )
            for conversation, tok_kwargs in zip(
                maybe_tqdm(
                    seq_convs,
                    use_tqdm=use_tqdm,
                    desc="Rendering conversations",
                ),
                seq_tok_kwargs,
            )
        )

        return self._render_and_run_requests(
            prompts=rendered_prompts,
            params=seq_params,
            output_type=output_type,
            lora_requests=seq_lora_requests,
            use_tqdm=use_tqdm,
        )

1876
1877
1878
1879
    def _render_and_run_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ):
        """Add every rendered prompt to the engine, then run it to completion.

        Args:
            prompts: Rendered prompts. A generator is preferred; a list or
                tuple triggers a one-time efficiency warning.
            params: One params object per prompt.
            output_type: Expected type of the engine outputs.
            lora_requests: Optional per-prompt LoRA requests.
            priorities: Optional per-prompt priorities.
            use_tqdm: Progress-bar setting forwarded to `_run_engine`.

        Returns:
            The outputs returned by `_run_engine`.
        """
        already_rendered = isinstance(prompts, (list, tuple))
        if already_rendered:
            message = (
                "Rendering all prompts before adding them to the engine "
                "is less efficient than performing both on the same prompt "
                "before processing the next prompt. You should instead pass "
                "a generator that renders one prompt per iteration, as that allows "
                "engine execution to begin for the first prompt while processing "
                "the next prompt."
            )
            logger.warning_once(message)

        self._render_and_add_requests(
            prompts=prompts,
            params=params,
            lora_requests=lora_requests,
            priorities=priorities,
        )

        return self._run_engine(output_type, use_tqdm=use_tqdm)
1904

1905
    def _render_and_add_requests(
        self,
        prompts: Iterable[ProcessorInputs],
        params: Sequence[SamplingParams | PoolingParams],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
    ) -> list[str]:
        """Add each prompt to the engine as a request.

        If adding any request fails, the requests added so far are aborted
        and the exception is re-raised.

        Args:
            prompts: Rendered prompts, consumed one at a time.
            params: One params object per prompt, indexed by position.
            lora_requests: Optional per-prompt LoRA requests; each is passed
                through `_resolve_mm_lora` together with its prompt.
            priorities: Optional per-prompt priorities; 0 is used when
                omitted.

        Returns:
            The IDs of the added requests, in submission order.
        """
        added_request_ids: list[str] = []

        try:
            for index, prompt in enumerate(prompts):
                request_params = params[index]
                requested_lora = (
                    lora_requests[index] if lora_requests is not None else None
                )
                resolved_lora = self._resolve_mm_lora(prompt, requested_lora)
                request_priority = priorities[index] if priorities is not None else 0

                added_request_ids.append(
                    self._add_request(
                        prompt,
                        request_params,
                        lora_request=resolved_lora,
                        priority=request_priority,
                    )
                )
        except Exception as e:
            # Do not leave a partially submitted batch behind in the engine.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            raise e

        return added_request_ids

1934
    def _add_request(
        self,
        prompt: ProcessorInputs,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
    ) -> str:
        """Submit one request to the engine under a freshly generated ID.

        Args:
            prompt: The rendered prompt.
            params: Sampling or pooling parameters. Sampling parameters have
                their `output_kind` set to `FINAL_ONLY` in place.
            lora_request: Optional LoRA request.
            priority: Request priority.

        Returns:
            The value returned by the engine's `add_request`.
        """
        # We only care about the final output
        if isinstance(params, SamplingParams):
            params.output_kind = RequestOutputKind.FINAL_ONLY

        # IDs are drawn from the shared request counter.
        next_id = next(self.request_counter)
        request_id = str(next_id)

        return self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
        )
1954

1955
    def _run_engine(
1956
        self,
1957
        output_type: type[_O] | tuple[type[_O], ...],
1958
1959
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
1960
    ) -> list[_O]:
1961
1962
        # Initialize tqdm.
        if use_tqdm:
Zhuohan Li's avatar
Zhuohan Li committed
1963
            num_requests = self.llm_engine.get_num_unfinished_requests()
1964
1965
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
1966
1967
1968
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
1969
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
1970
            )
1971

Zhuohan Li's avatar
Zhuohan Li committed
1972
        # Run the engine.
1973
        outputs: list[_O] = []
1974
1975
        total_in_toks = 0
        total_out_toks = 0
Zhuohan Li's avatar
Zhuohan Li committed
1976
1977
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
1978
            for output in step_outputs:
1979
                assert isinstance(output, output_type)
1980
                if output.finished:
1981
                    outputs.append(output)  # type: ignore[arg-type]
1982
                    if use_tqdm:
1983
1984
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
1985
                            n = len(output.outputs)
1986
                            assert output.prompt_token_ids is not None
1987
                            total_in_toks += len(output.prompt_token_ids) * n
1988
1989
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(
1990
1991
1992
                                len(stp.token_ids) for stp in output.outputs
                            )
                            out_spd = total_out_toks / pbar.format_dict["elapsed"]
1993
1994
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
1995
1996
                                f"output: {out_spd:.2f} toks/s"
                            )
1997
                            pbar.update(n)
1998
1999
                        else:
                            pbar.update(1)
2000
2001
                        if pbar.n == num_requests:
                            pbar.refresh()
2002

2003
2004
        if use_tqdm:
            pbar.close()
2005
2006
2007
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
2008
        return sorted(outputs, key=lambda x: int(x.request_id))
2009

2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
    def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest | dict
    ) -> None:
        """
        Initialize weight transfer for RL training.

        The request's `init_info` is sent to the workers through the
        `init_weight_transfer_engine` collective RPC.

        Args:
            request: Weight transfer initialization request with
                backend-specific info. Either a request object exposing
                `init_info`, or a dict with an `"init_info"` key.
        """
        if isinstance(request, dict):
            init_info_dict = request["init_info"]
        else:
            init_info_dict = request.init_info

        rpc_kwargs = {"init_info": init_info_dict}
        self.llm_engine.collective_rpc("init_weight_transfer_engine", kwargs=rpc_kwargs)

    def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
        """
        Update the weights of the model.

        The request's `update_info` is sent to the workers through the
        `update_weights` collective RPC.

        Args:
            request: Weight update request with backend-specific update
                info. Either a request object exposing `update_info`, or a
                dict with an `"update_info"` key.
        """
        if isinstance(request, dict):
            update_info_dict = request["update_info"]
        else:
            update_info_dict = request.update_info

        rpc_kwargs = {"update_info": update_info_dict}
        self.llm_engine.collective_rpc("update_weights", kwargs=rpc_kwargs)

2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model.

        The text is fetched from the workers once and cached on the
        instance, so later calls do not repeat the collective RPC.
        """
        if self._cached_repr is not None:
            return self._cached_repr

        results = self.llm_engine.collective_rpc("get_model_inspection")
        if results:
            # In distributed settings every worker replies; the replies
            # should all be the same, so the first one is used.
            self._cached_repr = results[0]
        else:
            self._cached_repr = f"LLM(model={self.model_config.model!r})"

        return self._cached_repr