llm.py 78.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Callable, Iterable, Sequence
6
from pathlib import Path
7
from typing import TYPE_CHECKING, Any
8

9
import cloudpickle
10
import torch.nn as nn
11
from pydantic import ValidationError
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar, overload
14

15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
22
    AttentionConfig,
23
    CompilationConfig,
24
    PoolerConfig,
25
    ProfilerConfig,
26
27
28
    StructuredOutputsConfig,
    is_init_field,
)
29
from vllm.config.compilation import CompilationMode
30
from vllm.config.model import (
31
32
    ConvertOption,
    HfOverrides,
33
    ModelDType,
34
    RunnerOption,
35
    TokenizerMode,
36
)
37
38
39
from vllm.config.quantization import (
    OnlineQuantizationConfigArgs,
)
40
41
42
43
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
44
from vllm.engine.arg_utils import EngineArgs
45
46
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
47
    ChatTemplateConfig,
48
    ChatTemplateContentFormatOption,
49
    load_chat_template,
50
)
51
from vllm.entrypoints.pooling.factories import init_pooling_io_processors
52
from vllm.entrypoints.pooling.scoring.io_processor import ScoringIOProcessor
53
54
from vllm.entrypoints.pooling.scoring.typing import ScoreInput
from vllm.entrypoints.pooling.typing import OfflineInputsContext, OfflineOutputsContext
55
from vllm.entrypoints.utils import log_non_default_args
56
from vllm.inputs import (
57
    DataPrompt,
58
    EngineInput,
59
60
61
62
    PromptType,
    TextPrompt,
    TokensPrompt,
)
63
from vllm.logger import init_logger
64
from vllm.lora.request import LoRARequest
65
from vllm.model_executor.layers.quantization import QuantizationMethods
66
67
68
69
70
71
72
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
73
from vllm.platforms import current_platform
74
from vllm.pooling_params import PoolingParams
75
from vllm.renderers import ChatParams, merge_kwargs
76
77
78
79
80
from vllm.renderers.inputs.preprocess import (
    conversation_to_seq,
    parse_model_prompt,
    prompt_to_seq,
)
81
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
82
from vllm.tasks import PoolingTask
83
from vllm.tokenizers import TokenizerLike
yhu422's avatar
yhu422 committed
84
from vllm.usage.usage_lib import UsageContext
85
from vllm.utils.counter import Counter
86
from vllm.utils.mistral import is_mistral_tokenizer
87
from vllm.utils.tqdm_utils import maybe_tqdm
88
from vllm.v1.engine import PauseMode
89
from vllm.v1.engine.llm_engine import LLMEngine
90
from vllm.v1.sample.logits_processor import LogitsProcessor
91

92
93
94
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

95
96
logger = init_logger(__name__)

# Output object type collected from the engine loop: either a generation
# result or a pooling result (both allowed by default).
_O = TypeVar(
    "_O",
    default=RequestOutput | PoolingRequestOutput,
    bound=RequestOutput | PoolingRequestOutput,
)
# Per-request parameter type: sampling params, pooling params, or None.
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)
# Generic result type (used by collective_rpc / apply_model / config helpers).
_R = TypeVar("_R", default=Any)

105
106

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
107
108
109
110
111
112
113
114
115
116
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
117
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
118
119
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
120
121
122
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
123
124
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
125
126
127
128
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
129
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
131
132
133
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
134
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
135
136
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
137
        quantization: The method used to quantize the model weights. Currently,
138
            we support "awq", "gptq", and "fp8" (experimental).
139
140
141
142
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
143
144
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
145
146
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
147
        chat_template: The chat template to apply.
148
149
150
151
152
153
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
154
155
156
157
158
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vllm can automatically infer the kv cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the kv cache memory size. kv_cache_memory_bytes
            allows more fine-grain control of how much memory gets used when
159
            compared with using gpu_memory_utilization. Note that when
            kv_cache_memory_bytes is not None, gpu_memory_utilization is
            ignored.
162
163
164
165
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
166
167
168
169
170
171
172
173
174
175
176
177
178
        offload_group_size: Prefetch offloading: Group every N layers
            together. Offload last `offload_num_in_group` layers of each group.
            Default is 0 (disabled).
        offload_num_in_group: Prefetch offloading: Number of layers to
            offload per group. Default is 1.
        offload_prefetch_step: Prefetch offloading: Number of layers to
            prefetch ahead. Higher values hide more latency but use more GPU
            memory. Default is 1.
        offload_params: Prefetch offloading: Set of parameter name segments
            to selectively offload. Only parameters whose names contain one of
            these segments will be offloaded (e.g., {"gate_up_proj", "down_proj"}
            for MLP weights, or {"w13_weight", "w2_weight"} for MoE expert
            weights). If None or empty, all parameters are offloaded.
179
180
181
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
182
        enable_return_routed_experts: Whether to return routed experts.
183
184
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
185
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `hf auth login` (stored in `~/.cache/huggingface/token`).
188
189
190
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
191
192
193
194
195
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
196
197
        pooler_config: Initialize non-default pooling config for the pooling model,
            e.g., `PoolerConfig(seq_pooling_type="MEAN", use_activation=False)`.
198
        compilation_config: Either an integer or a dictionary. If it is an
199
            integer, it is used as the mode of compilation optimization. If it
200
            is a dictionary, it can specify the full compilation configuration.
201
202
203
204
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
205
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
206

207
208
    Note:
        This class is intended to be used for offline inference. For online
209
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
210
    """
211
212
213
214

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        chat_template: Path | str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        cpu_offload_gb: float = 0,
        offload_group_size: int = 0,
        offload_num_in_group: int = 1,
        offload_prefetch_step: int = 1,
        offload_params: set[str] | None = None,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        quantization_config: dict[str, Any]
        | OnlineQuantizationConfigArgs
        | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes legacy and dict-form keyword arguments into the config
        objects expected by `EngineArgs`, builds the `LLMEngine`, and caches
        the state used by the offline generation/pooling APIs. See the class
        docstring for the meaning of each parameter.

        Raises:
            ValueError: If `kv_transfer_config` is a dict that fails
                validation, or if `data_parallel_size > 1` is requested
                without the `external_launcher` backend on a non-TPU platform.
        """
        if "swap_space" in kwargs:
            kwargs.pop("swap_space")
            import warnings

            warnings.warn(
                "The 'swap_space' parameter is deprecated and ignored. "
                "It will be removed in a future version.",
                DeprecationWarning,
                stacklevel=2,
            )

        # Offline usage defaults to no stats logging unless explicitly asked.
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if "kv_transfer_config" in kwargs and isinstance(
            kwargs["kv_transfer_config"], dict
        ):
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance.

            `None` yields `cls()`. A dict is passed as keyword arguments after
            dropping keys that are not init fields of `cls`; any dropped keys
            are logged so typos are not silently lost. Anything else is
            returned unchanged.
            """
            if value is None:
                return cls()
            if isinstance(value, dict):
                init_kwargs: dict[str, Any] = {}
                ignored_keys: list[str] = []
                for key, item in value.items():
                    if is_init_field(cls, key):
                        init_kwargs[key] = item
                    else:
                        ignored_keys.append(key)
                if ignored_keys:
                    logger.warning(
                        "Ignoring keys that are not init fields of %s: %s",
                        cls.__name__,
                        ignored_keys,
                    )
                return cls(**init_kwargs)  # type: ignore[arg-type]
            return value

        # An int compilation_config is shorthand for the compilation mode.
        if isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data-parallel usage (unsupported, may hang).
        # The external_launcher backend and TPU platforms are exempt.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            cpu_offload_gb=cpu_offload_gb,
            offload_group_size=offload_group_size,
            offload_num_in_group=offload_num_in_group,
            offload_prefetch_step=offload_prefetch_step,
            offload_params=offload_params or set(),
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            quantization_config=quantization_config,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.model_config = self.llm_engine.model_config
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        self.supported_tasks = supported_tasks

        self.pooling_task = self.model_config.get_pooling_task(supported_tasks)
        if self.pooling_task is not None:
            logger.info("Supported pooling task: %s", self.pooling_task)

        self.runner_type = self.model_config.runner_type
        self.renderer = self.llm_engine.renderer
        self.chat_template = load_chat_template(chat_template)
        self.input_processor = self.llm_engine.input_processor
        self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
        self.pooling_io_processors = init_pooling_io_processors(
            supported_tasks=supported_tasks,
            vllm_config=self.llm_engine.vllm_config,
            renderer=self.renderer,
            chat_template_config=self.chat_template_config,
        )

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

410
411
412
413
414
    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLM":
        """Create an LLM instance from EngineArgs."""
        return cls(**vars(engine_args))

415
    def get_tokenizer(self) -> TokenizerLike:
416
        return self.llm_engine.get_tokenizer()
417

418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
    def get_world_size(self, include_dp: bool = True) -> int:
        """Get the world size from the parallel config.

        Args:
            include_dp: If True (default), returns the world size including
                data parallelism (TP * PP * DP). If False, returns the world
                size without data parallelism (TP * PP).

        Returns:
            The world size (tensor_parallel_size * pipeline_parallel_size),
            optionally multiplied by data_parallel_size if include_dp is True.
        """
        parallel_config = self.llm_engine.vllm_config.parallel_config
        if include_dp:
            return parallel_config.world_size_across_dp
        return parallel_config.world_size

435
    def reset_mm_cache(self) -> None:
436
        self.renderer.clear_mm_cache()
437
438
        self.llm_engine.reset_mm_cache()

439
    def get_default_sampling_params(self) -> SamplingParams:
        """Return a new `SamplingParams` seeded with the model's defaults.

        The model's overrides (`model_config.get_diff_sampling_param()`) are
        stored on `self.default_sampling_params` the first time they are
        needed. If there are no overrides, plain `SamplingParams()` is
        returned. A fresh object is created on every call.
        """
        if self.default_sampling_params is None:
            self.default_sampling_params = self.model_config.get_diff_sampling_param()

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

446
447
    def generate(
        self,
448
449
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
450
        *,
451
        use_tqdm: bool | Callable[..., tqdm] = True,
452
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
453
        priority: list[int] | None = None,
454
        tokenization_kwargs: dict[str, Any] | None = None,
455
    ) -> list[RequestOutput]:
Woosuk Kwon's avatar
Woosuk Kwon committed
456
457
        """Generates the completions for the input prompts.

458
        This class automatically batches the given prompts, considering
Woosuk Kwon's avatar
Woosuk Kwon committed
459
460
461
462
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
463
            prompts: The prompts to the LLM. You may pass a sequence of prompts
464
                for batch inference. See [PromptType][vllm.inputs.PromptType]
465
                for more details about the format of each prompt.
Woosuk Kwon's avatar
Woosuk Kwon committed
466
            sampling_params: The sampling parameters for text generation. If
nunjunj's avatar
nunjunj committed
467
468
469
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
470
                prompts and it is paired one by one with the prompt.
471
472
473
474
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
475
            lora_request: LoRA request to use for generation, if any.
476
477
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
478
479
480
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the prompt
                at the same index.
481
            tokenization_kwargs: Overrides for `tokenizer.encode`.
Woosuk Kwon's avatar
Woosuk Kwon committed
482
483

        Returns:
484
            A list of `RequestOutput` objects containing the
485
486
            generated completions in the same order as the input prompts.
        """
487
        runner_type = self.model_config.runner_type
488
        if runner_type != "generate":
489
490
491
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
492
493
                "generative model."
            )
494

495
        if sampling_params is None:
496
            sampling_params = self.get_default_sampling_params()
497

498
        return self._run_completion(
499
            prompts=prompts,
500
            params=sampling_params,
501
            output_type=RequestOutput,
502
            use_tqdm=use_tqdm,
503
            lora_request=lora_request,
504
            tokenization_kwargs=tokenization_kwargs,
505
506
            priority=priority,
        )
507

508
509
510
511
    def enqueue(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
512
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
        priority: list[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Enqueue prompts for generation without waiting for completion.

        This method adds requests to the engine queue but does not start
        processing them. Use wait_for_completion() to process the queued
        requests and get results.

        Args:
            prompts: The prompts to the LLM. See generate() for details.
            sampling_params: The sampling parameters for text generation.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
            use_tqdm: If True, shows a tqdm progress bar while adding requests.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of request IDs for the enqueued requests.
        """
534
        runner_type = self.model_config.runner_type
535
536
537
538
539
540
        if runner_type != "generate":
            raise ValueError("LLM.enqueue() is only supported for generative models.")

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

541
542
543
544
545
546
547
        return self._add_completion_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
548
549
        )

550
    @overload
551
552
    def wait_for_completion(
        self,
553
        *,
554
        use_tqdm: bool | Callable[..., tqdm] = True,
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
    ) -> list[RequestOutput | PoolingRequestOutput]: ...

    @overload
    def wait_for_completion(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]: ...

    def wait_for_completion(
        self,
        output_type: type[Any] | tuple[type[Any], ...] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[Any]:
571
572
573
574
575
576
        """Wait for all enqueued requests to complete and return results.

        This method processes all requests currently in the engine queue
        and returns their outputs. Use after enqueue() to get results.

        Args:
577
            output_type: The expected output type, defaults to RequestOutput.
578
579
580
            use_tqdm: If True, shows a tqdm progress bar.

        Returns:
581
            A list of output objects for all completed requests.
582
        """
583
584
585
586
        if output_type is None:
            output_type = (RequestOutput, PoolingRequestOutput)

        return self._run_engine(output_type, use_tqdm=use_tqdm)
587

Cyrus Leung's avatar
Cyrus Leung committed
588
    def _resolve_mm_lora(
589
        self,
590
        prompt: EngineInput,
591
        lora_request: LoRARequest | None,
Cyrus Leung's avatar
Cyrus Leung committed
592
593
594
595
596
597
598
    ) -> LoRARequest | None:
        if prompt["type"] != "multimodal":
            return lora_request

        lora_config = self.llm_engine.vllm_config.lora_config
        default_mm_loras = None if lora_config is None else lora_config.default_mm_loras
        if not default_mm_loras:
599
600
            return lora_request

601
602
        prompt_modalities = prompt["mm_placeholders"].keys()
        intersection = set(prompt_modalities).intersection(default_mm_loras.keys())
603
604
        if not intersection:
            return lora_request
Cyrus Leung's avatar
Cyrus Leung committed
605

606
607
608
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
Cyrus Leung's avatar
Cyrus Leung committed
609
610
611
612
                "Multiple modality specific loras were registered and would be "
                "used by a single prompt consuming several modalities; "
                "currently we only support one lora per request; as such, "
                "lora(s) registered with modalities: %s will be skipped",
613
614
                intersection,
            )
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # If we have a collision, warn if there is a collision,
        # but always send the explicitly provided request.
        if lora_request:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
630
631
                    "lora_request as we only apply one LoRARequest per prompt"
                )
632
633
634
635
636
637
638
639
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

640
641
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Invoke `method` on every worker and gather what each one returns.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to every worker for execution.

                A callable receives the worker object as an extra leading
                `self` argument, followed by `args` and `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait. A
                [`TimeoutError`][] is raised when it is exceeded; `None`
                waits forever.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list with the result produced by each worker.

        Note:
            Prefer this API for control messages only; move bulk data over a
            separately established data-plane channel instead.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model held by each worker and collect the value
        every worker returns.

        !!! warning
            Return values travel back from the workers, so keep them small.
            If you must return large arrays or tensors, move them to CPU
            first so they do not take up additional VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
    def beam_search(
        self,
        prompts: list[TokensPrompt | TextPrompt],
        params: BeamSearchParams,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        use_tqdm: bool = False,
        concurrency_limit: int | None = None,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of text or tokenized prompts (`TextPrompt` /
                `TokensPrompt`). Embedding prompts are not supported.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            use_tqdm: Whether to use tqdm to display the progress bar.
                Ignored (with a warning) when `concurrency_limit` is set.
            concurrency_limit: The maximum number of prompts whose beams are
                expanded together in one batch. If None, all prompts are
                processed in a single batch.

        Returns:
            One `BeamSearchOutput` per prompt, in input order, each holding at
            most `params.beam_width` sequences sorted best-first.

        Raises:
            ValueError: If `concurrency_limit` is given but is not a positive
                integer.
            NotImplementedError: If any prompt is an embedding prompt.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency penalty, and stopping criteria, etc.?
        if concurrency_limit is not None and concurrency_limit < 1:
            # It is used as the `range` step below: zero would raise an opaque
            # error there and a negative value would silently skip generation.
            raise ValueError(
                "concurrency_limit must be a positive integer or None, "
                f"got {concurrency_limit!r}"
            )

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        engine_inputs = self._preprocess_cmpl(prompts)
        lora_requests = self._lora_request_to_seq(lora_request, len(engine_inputs))

        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar."
            )
            use_tqdm = False

        if concurrency_limit is None:
            # Process every prompt in a single batch. Clamp to 1 so that an
            # empty prompt list returns [] instead of `range(0, 0, 0)` raising.
            concurrency_limit = max(len(engine_inputs), 1)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        sampling_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
        )
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, engine_inputs):
            if prompt["type"] == "embeds":
                raise NotImplementedError(
                    "Embedding prompt not supported for beam search"
                )

            instances.append(
                BeamSearchInstance(
                    prompt,
                    lora_request=lora_req,
                    logprobs=None,
                ),
            )

        for prompt_start in range(0, len(instances), concurrency_limit):
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress."
                )

            for _ in token_iter:
                # Flatten the live beams of every instance in the batch into one
                # request list; chain is linear, unlike `sum(lists, [])`.
                all_beams: list[BeamSearchSequence] = list(
                    itertools.chain.from_iterable(
                        instance.beams for instance in instances_batch
                    )
                )
                # all_beams[pos[k]:pos[k + 1]] are the beams of instance k.
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch
                    )
                )
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:])
                )

                if not all_beams:
                    break

                # only runs for one step
                # we don't need to use tqdm here
                output = self._render_and_run_requests(
                    prompts=(beam.get_prompt() for beam in all_beams),
                    params=self._params_to_seq(sampling_params, len(all_beams)),
                    output_type=RequestOutput,
                    lora_requests=[beam.lora_request for beam in all_beams],
                    use_tqdm=False,
                )

                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    current_beam.orig_prompt,
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs + [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                )

                                if token_id == eos_token_id and not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
                    instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive when the token budget ran out also compete.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)

            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def _preprocess_cmpl(
        self,
        prompts: Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[EngineInput]:
        """
        Turn prompts given to the completion-style LLM APIs (every API except
        [LLM.chat][]) into inputs that `_add_request` accepts.

        See [LLM.generate][] for the full description of the arguments.

        Returns:
            A sequence of `EngineInput` objects ready to be passed into
            LLMEngine.
        """
        renderer = self.renderer
        model_config = self.model_config

        parsed = [parse_model_prompt(model_config, item) for item in prompts]

        overrides = tokenization_kwargs or {}
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(**overrides)

        return renderer.render_cmpl(parsed, tok_params)
    def _preprocess_cmpl_one(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> EngineInput:
        """
        Single-prompt convenience wrapper around `_preprocess_cmpl`.

        The unpacking raises ValueError if the batch call does not produce
        exactly one input.
        """
        results = self._preprocess_cmpl([prompt], tokenization_kwargs)
        [engine_input] = results
        return engine_input
    def _preprocess_chat(
        self,
        conversations: Sequence[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[EngineInput]:
        """
        Render chat conversations into engine inputs that the other LLM APIs
        can consume.

        See [LLM.chat][] for the full description of the arguments.

        Returns:
            A sequence of `EngineInput` objects ready to be passed into
            LLMEngine.
        """
        renderer = self.renderer

        # The explicit generation flags, tools and tokenize mode are merged
        # with any user-supplied template kwargs via `merge_kwargs`.
        template_kwargs = merge_kwargs(
            chat_template_kwargs,
            {
                "add_generation_prompt": add_generation_prompt,
                "continue_final_message": continue_final_message,
                "tools": tools,
                "tokenize": is_mistral_tokenizer(renderer.tokenizer),
            },
        )
        chat_params = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=template_kwargs,
        )

        overrides = tokenization_kwargs or {}
        tok_params = renderer.default_chat_tok_params.with_kwargs(**overrides)

        _unused, engine_inputs = renderer.render_chat(
            conversations,
            chat_params,
            tok_params,
            prompt_extras={"mm_processor_kwargs": mm_processor_kwargs},
        )
        return engine_inputs
    def _preprocess_chat_one(
        self,
        conversation: list[ChatCompletionMessageParam],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> EngineInput:
        """
        Single-conversation convenience wrapper around `_preprocess_chat`.

        The unpacking raises ValueError if rendering does not produce exactly
        one input.
        """
        results = self._preprocess_chat(
            [conversation],
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )
        [engine_input] = results
        return engine_input
    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Each conversation is turned into a prompt with the tokenizer's chat
        template, and responses are then generated as with the
        [generate][vllm.LLM.generate] method.

        Multi-modal inputs are accepted in the same shape as for the OpenAI
        API.

        Args:
            messages: Either a single conversation or a sequence of them.

                - A conversation is a list of messages.
                - A message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, the default sampling parameters are used. A single
                value applies to every conversation; a list must have the same
                length as the conversations and is paired with them one by one.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template is used.
            chat_template_content_format: The format to render message content.

                - "string" renders the content as a string.
                  Example: `"Who are you?"`
                - "openai" renders the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions, passed through to `_run_chat`.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            mm_processor_kwargs: Overrides for `processor.__call__`.

        Returns:
            A list of `RequestOutput` objects holding the generated responses,
            in the same order as the input messages.

        Raises:
            ValueError: If the model is not run with the "generate" runner.
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.chat() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        # Only fetch the defaults when the caller supplied nothing.
        params = (
            self.get_default_sampling_params()
            if sampling_params is None
            else sampling_params
        )

        return self._run_chat(
            messages=messages,
            params=params,
            output_type=RequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If a `data` prompt is used with a task other than
                "plugin", if `pooling_task` fails `_verify_pooling_task`, if no
                IO processor is registered for `pooling_task`, or if a pooling
                parameter already carries a different task.
        """
        if isinstance(prompts, dict) and "data" in prompts and pooling_task != "plugin":
            raise ValueError(
                "The 'data' field is only supported for the 'plugin' pooling task."
            )
        self._verify_pooling_task(pooling_task)
        # `_verify_pooling_task` rejects None; this assert only narrows the type.
        assert pooling_task is not None

        # A missing IO processor is a user-facing condition, so raise a real
        # error rather than relying on an assert (which `python -O` strips,
        # leaving an unhelpful KeyError below).
        if pooling_task not in self.pooling_io_processors:
            raise ValueError(
                f"No pooling IO processor is available for task "
                f"{pooling_task!r}. Available tasks: "
                f"{list(self.pooling_io_processors)}"
            )
        io_processor = self.pooling_io_processors[pooling_task]

        if pooling_params is None:
            pooling_params = PoolingParams()

        ctx = OfflineInputsContext(
            prompts=prompts,
            pooling_params=pooling_params,
            tokenization_kwargs=tokenization_kwargs,
        )

        engine_inputs = io_processor.pre_process_offline(ctx)
        n_inputs = len(engine_inputs)
        assert ctx.pooling_params is not None

        params_seq = self._params_to_seq(ctx.pooling_params, n_inputs)

        for param in params_seq:
            if param.task is None:
                param.task = pooling_task
            elif pooling_task == "plugin":
                # `plugin` task uses io_processor.parse_request to verify inputs.
                # We actually allow plugin to overwrite pooling_task.
                pass
            elif param.task != pooling_task:
                msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                raise ValueError(msg)

        seq_lora_requests = self._lora_request_to_seq(lora_request, n_inputs)
        seq_priority = self._priority_to_seq(None, n_inputs)

        self._render_and_add_requests(
            prompts=engine_inputs,
            params=params_seq,
            lora_requests=seq_lora_requests,
            priorities=seq_priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm, output_type=PoolingRequestOutput)
        outputs = io_processor.post_process_offline(
            ctx=OfflineOutputsContext(outputs=outputs)
        )
        return outputs
    def _verify_pooling_task(self, pooling_task: PoolingTask | None):
        """Validate that `pooling_task` can be served by this LLM instance.

        The checks run in this order:

        1. The runner type must be "pooling".
        2. A task must be given.
        3. Embed-family and classify-family tasks must be in
           `self.supported_tasks`.
        4. A task other than "plugin" that differs from the default
           `self.pooling_task` must be in `self.supported_tasks`. When it is,
           a one-time deprecation warning is logged.
        5. The "plugin" task requires a registered "plugin" IO processor.

        Raises:
            ValueError: If any of the checks above fails.
        """
        if self.runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        if (
            pooling_task in ("embed", "token_embed")
            and pooling_task not in self.supported_tasks
        ):
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        if (
            pooling_task in ("classify", "token_classify")
            and pooling_task not in self.supported_tasks
        ):
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        # plugin task uses io_processor.parse_request to verify inputs
        if pooling_task != "plugin" and pooling_task != self.pooling_task:
            if pooling_task not in self.supported_tasks:
                # Separate the two sentences; previously they ran together as
                # "Unsupported task: 'x' Supported tasks: ...".
                raise ValueError(
                    f"Unsupported task: {pooling_task!r}. "
                    f"Supported tasks: {self.supported_tasks}"
                )
            else:
                logger.warning_once(
                    "Pooling multitask support is deprecated and will "
                    "be removed in v0.20. When the default pooling task is "
                    "not what you want, you need to manually specify it "
                    'via PoolerConfig(task="%s"). ',
                    pooling_task,
                )

        if pooling_task == "plugin" and "plugin" not in self.pooling_io_processors:
            raise ValueError(
                "No IOProcessor plugin installed. Please refer "
                "to the documentation and to the "
                "'prithvi_geospatial_mae_io_processor' "
                "offline inference example for more details."
            )

    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per prompt.

        This is `encode` with `pooling_task="embed"`. Prompts are batched
        automatically with the memory constraint in mind, so for the best
        performance pass all of your prompts in a single list.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, the
                default pooling parameters are used.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `EmbeddingRequestOutput` objects holding the embedding
            vectors, in the same order as the input prompts.
        """
        pooled = self.encode(
            prompts,
            pooling_task="embed",
            pooling_params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This runs `encode` with `pooling_task="classify"`. This method
        automatically batches the given prompts, considering the memory
        constraint. For the best performance, put all of your prompts into a
        single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects, one per prompt,
            holding the classification results in the same order as the
            input prompts.
        """
        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Compute rewards for each prompt.

        Rewards are produced by running `encode` with the "token_classify"
        pooling task.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, the
                default pooling parameters are used.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects holding the pooled hidden
            states, in the same order as the input prompts.
        """
        results = self.encode(
            prompts,
            pooling_task="token_classify",
            pooling_params=pooling_params,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
            tokenization_kwargs=tokenization_kwargs,
        )
        return results

    def score(
        self,
        data_1: ScoreInput | list[ScoreInput],
        data_2: ScoreInput | list[ScoreInput],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
        `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`. In the `1 -> N`
        case the single `data_1` input is replicated `N` times so that it
        pairs with every `data_2` input.

        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: A single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, a
                default `PoolingParams()` is used. Each resulting params
                object must have an unset `task` or one equal to the scorer's
                pooling task; an unset `task` is filled in.
            lora_request: LoRA request(s) to use for scoring, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: The chat template to use for the scoring. If None,
                the model's default chat template is used.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run with the pooling runner, is a
                cross-encoder whose `num_labels` is not 1, has no scoring IO
                processor for its score type, or a params `task` conflicts
                with the scorer's pooling task.
        """
        if self.runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        model_config = self.model_config
        score_type = model_config.score_type

        # Cross-encoders must produce a single logit to be usable as a score.
        is_cross_encoder = score_type == "cross-encoder"
        if is_cross_encoder and (
            getattr(model_config.hf_config, "num_labels", 0) != 1
        ):
            raise ValueError("Scoring API is only enabled for num_labels == 1.")

        processors = self.pooling_io_processors
        if score_type is None or score_type not in processors:
            raise ValueError("This model does not support the Scoring API.")

        io_processor = processors[score_type]
        assert isinstance(io_processor, ScoringIOProcessor)

        pooling_task = io_processor.pooling_task
        scoring_data = io_processor.valid_inputs(data_1, data_2)
        n_queries = len(scoring_data.data_1)

        if pooling_params is None:
            pooling_params = PoolingParams()

        ctx = OfflineInputsContext(
            prompts=scoring_data,
            pooling_params=pooling_params,
            tokenization_kwargs=tokenization_kwargs,
            chat_template=chat_template,
            n_queries=n_queries,
        )

        engine_inputs = io_processor.pre_process_offline(ctx)
        num_inputs = len(engine_inputs)

        lora_seq = self._lora_request_to_seq(lora_request, num_inputs)
        params_seq = self._params_to_seq(ctx.pooling_params, num_inputs)

        # Fill in the pooling task, refusing to silently replace a different one.
        for param in params_seq:
            if param.task is None:
                param.task = pooling_task
                continue
            if param.task != pooling_task:
                msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                raise ValueError(msg)

        self._render_and_add_requests(
            prompts=engine_inputs,
            params=params_seq,
            lora_requests=lora_seq,
            priorities=self._priority_to_seq(None, num_inputs),
        )

        raw_outputs = self._run_engine(
            use_tqdm=use_tqdm, output_type=PoolingRequestOutput
        )
        processed = io_processor.post_process_offline(
            ctx=OfflineOutputsContext(outputs=raw_outputs, n_queries=n_queries),
        )
        return list(map(ScoringRequestOutput.from_base, processed))
1442

1443
1444
1445
1446
1447
1448
1449
1450
1451
    def start_profile(self, profile_prefix: str | None = None) -> None:
        """Start profiling with optional custom trace prefix.

        Args:
            profile_prefix: Optional prefix for the trace file names. If provided,
                           trace files will be named as "<prefix>_dp<X>_pp<Y>_tp<Z>".
                           If not provided, default naming will be used.
        """
        self.llm_engine.start_profile(profile_prefix)
1452
1453
1454
1455

    def stop_profile(self) -> None:
        """Stop profiling on the engine (counterpart of `start_profile`)."""
        engine = self.llm_engine
        engine.stop_profile()

1456
1457
1458
1459
1460
1461
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache.

        Both flags are forwarded positionally to `llm_engine.reset_prefix_cache`.

        Returns:
            The boolean reported by the engine (presumably whether the reset
            took effect — confirm against the engine implementation).
        """
        engine = self.llm_engine
        result = engine.reset_prefix_cache(reset_running_requests, reset_connector)
        return result
1462

1463
    def sleep(self, level: int = 1, mode: PauseMode = "abort"):
        """Put the engine to sleep.

        While asleep the engine should not process any requests. The caller
        should guarantee that no requests are being processed during the
        sleep period, before `wake_up` is called.

        Args:
            level: The sleep level.
                - Level 0: Pause scheduling but continue accepting requests.
                           Requests are queued but not processed.
                - Level 1: Offload model weights to CPU, discard KV cache.
                           The content of kv cache is forgotten. Good for
                           sleeping and waking up the engine to run the same
                           model again. Please make sure there's enough CPU
                           memory to store the model weights.
                - Level 2: Discard all GPU memory (weights + KV cache).
                           Good for sleeping and waking up the engine to run
                           a different model or update the model, where
                           previous model weights are not needed. It reduces
                           CPU memory pressure.
            mode: How to handle any existing requests, can be "abort", "wait",
                or "keep".
        """
        engine = self.llm_engine
        engine.sleep(level=level, mode=mode)
1487

1488
    def wake_up(self, tags: list[str] | None = None):
1489
        """
1490
1491
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1492

1493
        Args:
1494
1495
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1496
1497
1498
1499
                `("weights", "kv_cache", "scheduling")`. If None, all memory
                is reallocated. wake_up should be called with all tags
                (or None) before the engine is used again.
                Use tags=["scheduling"] to resume from level 0 sleep.
1500
1501
        """
        self.llm_engine.wake_up(tags)
1502

1503
1504
1505
1506
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects, as reported by
            `llm_engine.get_metrics()`, capturing the current state of all
            aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1515
    def _params_to_seq(
        self,
        params: _P | Sequence[_P],
        num_requests: int,
    ) -> Sequence[_P]:
        """Broadcast `params` to one entry per request.

        A single params object is repeated `num_requests` times (the same
        object, not copies). A sequence is returned unchanged after checking
        its length.

        Raises:
            ValueError: If `params` is a sequence whose length differs from
                `num_requests`.
        """
        if not isinstance(params, Sequence):
            return [params] * num_requests

        if len(params) != num_requests:
            raise ValueError(
                f"The lengths of prompts ({num_requests}) "
                f"and params ({len(params)}) must be the same."
            )
        return params

    def _lora_request_to_seq(
        self,
        lora_request: LoRARequest | None | Sequence[LoRARequest | None],
        num_requests: int,
    ) -> Sequence[LoRARequest | None]:
        """Broadcast `lora_request` to one entry per request.

        A single request (or `None`) is repeated `num_requests` times; a
        sequence is returned unchanged after checking its length.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from `num_requests`.
        """
        if not isinstance(lora_request, Sequence):
            return [lora_request] * num_requests

        if len(lora_request) != num_requests:
            raise ValueError(
                f"The lengths of prompts ({num_requests}) "
                f"and lora_request ({len(lora_request)}) must be the same."
            )
        return lora_request
1546

1547
1548
1549
1550
1551
    def _priority_to_seq(
        self,
        priority: list[int] | None,
        num_requests: int,
    ) -> Sequence[int]:
1552
1553
1554
1555
1556
1557
1558
        if priority is not None:
            if len(priority) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and priority ({len(priority)}) must be the same."
                )

1559
1560
1561
1562
            return priority

        return [0] * num_requests

1563
    def _add_completion_requests(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Render completion prompts one at a time and add them to the engine.

        The engine is not stepped here. Params, LoRA requests and priorities
        are broadcast to one entry per prompt first.

        Returns:
            The IDs of the requests that were added, in prompt order.
        """
        prompt_list = prompt_to_seq(prompts)
        count = len(prompt_list)
        params_per_prompt = self._params_to_seq(params, count)
        loras_per_prompt = self._lora_request_to_seq(lora_request, count)
        priority_per_prompt = self._priority_to_seq(priority, count)

        progress = maybe_tqdm(
            prompt_list, use_tqdm=use_tqdm, desc="Rendering prompts"
        )
        # Lazily render so each prompt is added as soon as it is ready.
        rendered = (
            self._preprocess_cmpl_one(item, tokenization_kwargs) for item in progress
        )
        return self._render_and_add_requests(
            prompts=rendered,
            params=params_per_prompt,
            lora_requests=loras_per_prompt,
            priorities=priority_per_prompt,
        )

1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
    def _run_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ):
        """Add completion requests for `prompts`, then run the engine.

        Returns:
            The finished outputs collected by `_run_engine`, sorted by
            request ID.
        """
        self._add_completion_requests(
            prompts,
            params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
        )
        finished = self._run_engine(output_type, use_tqdm=use_tqdm)
        return finished

1617
1618
1619
1620
1621
1622
1623
    def _run_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ):
        """Render chat conversations, submit them, and run the engine.

        Conversations are rendered lazily through a generator, one per
        iteration, and handed to `_render_and_run_requests`.

        Returns:
            The finished outputs returned by `_render_and_run_requests`.
        """
        conversations = conversation_to_seq(messages)
        count = len(conversations)
        params_per_conv = self._params_to_seq(params, count)
        loras_per_conv = self._lora_request_to_seq(lora_request, count)

        # When thinking is enabled or tools are provided, and the model
        # uses special tokens for structured output (e.g. Gemma4's
        # <|channel>, <|tool_call>, <|"|>), automatically set
        # skip_special_tokens=False so these tokens are preserved in
        # output.text for downstream parsing.
        thinking_on = bool(
            chat_template_kwargs and chat_template_kwargs.get("enable_thinking")
        )
        if thinking_on or tools:
            self._adjust_params_for_parsing(params_per_conv)

        progress = maybe_tqdm(
            conversations,
            use_tqdm=use_tqdm,
            desc="Rendering conversations",
        )
        rendered = (
            self._preprocess_chat_one(
                conv,
                chat_template=chat_template,
                chat_template_content_format=chat_template_content_format,
                chat_template_kwargs=chat_template_kwargs,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                tokenization_kwargs=tokenization_kwargs,
                mm_processor_kwargs=mm_processor_kwargs,
            )
            for conv in progress
        )
        return self._render_and_run_requests(
            prompts=rendered,
            params=params_per_conv,
            output_type=output_type,
            lora_requests=loras_per_conv,
            use_tqdm=use_tqdm,
        )

1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
    def _adjust_params_for_parsing(
        self, params: Sequence[SamplingParams | PoolingParams]
    ) -> None:
        """Set `skip_special_tokens=False` when the model encodes structured
        output syntax as special tokens.

        Models like Gemma4 register thinking delimiters
        (`<|channel>`/`<channel|>`) and tool call tokens
        (`<|tool_call>`/`<tool_call|>`/`<|"|>`) as special tokens. The
        default `skip_special_tokens=True` strips them from `output.text`,
        breaking parsing of both reasoning blocks and tool calls.

        This is a no-op for non-Gemma4 architectures and for models whose
        structured tokens are regular text tokens (e.g. DeepSeek's
        `<think>`/`</think>`).

        Note:
            Matching `SamplingParams` entries in `params` are modified in
            place; `PoolingParams` entries are left untouched.
        """
        # The offline API currently lacks a unified rendering pipeline.
        # Until the planned Renderer refactor is complete, we hardcode
        # this token preservation logic specifically for Gemma4 models
        # to avoid regressions on other models.
        hf_config = getattr(self.model_config, "hf_config", None)
        # `architectures` can exist on a config but be None (the HF default),
        # so treat None the same as a missing attribute.
        architectures = getattr(hf_config, "architectures", None) or ()
        if not any("Gemma4" in arch for arch in architectures):
            return

        tokenizer = self.renderer.get_tokenizer()
        vocab = tokenizer.get_vocab()
        special_ids = set(getattr(tokenizer, "all_special_ids", None) or ())

        # Tokens used for thinking delimiters and tool call syntax
        # that some models (Gemma4) register as special tokens.
        structured_tokens = (
            "<|channel>",
            "<channel|>",  # thinking delimiters
            "<|tool_call>",
            "<tool_call|>",  # tool call delimiters
            '<|"|>',  # string quoting in tool args
        )
        needs_special = any(
            vocab.get(tok) in special_ids for tok in structured_tokens if tok in vocab
        )
        if not needs_special:
            return

        for sp in params:
            if isinstance(sp, SamplingParams) and sp.skip_special_tokens:
                sp.skip_special_tokens = False

1724
1725
    def _render_and_run_requests(
        self,
        prompts: Iterable[EngineInput],
        params: Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ):
        """Add `prompts` to the engine, then step it until all requests finish.

        Passing an already-materialized list/tuple works but logs a one-time
        warning recommending a per-prompt generator instead.

        Returns:
            The finished outputs collected by `_run_engine`.
        """
        prompts_are_materialized = isinstance(prompts, (list, tuple))
        if prompts_are_materialized:
            logger.warning_once(
                "Rendering all prompts before adding them to the engine "
                "is less efficient than performing both on the same prompt "
                "before processing the next prompt. You should instead pass "
                "a generator that renders one prompt per iteration, as that allows "
                "engine execution to begin for the first prompt while processing "
                "the next prompt."
            )

        self._render_and_add_requests(
            prompts,
            params,
            lora_requests=lora_requests,
            priorities=priorities,
        )
        finished = self._run_engine(output_type, use_tqdm=use_tqdm)
        return finished
1752

1753
    def _render_and_add_requests(
        self,
        prompts: Iterable[EngineInput],
        params: Sequence[SamplingParams | PoolingParams],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
    ) -> list[str]:
        """Add each prompt to the engine as `prompts` yields it.

        `prompts` may be a lazy iterable (e.g. a generator rendering one
        prompt per iteration). It is consumed once, and the i-th prompt is
        paired with `params[i]`, `lora_requests[i]` and `priorities[i]`.

        Returns:
            The request IDs of the added requests, in input order.

        Raises:
            Re-raises whatever was raised while rendering or adding a
            request (including `KeyboardInterrupt`), after aborting every
            request already added in this call.
        """
        added_request_ids: list[str] = []

        try:
            for i, prompt in enumerate(prompts):
                request_id = self._add_request(
                    prompt,
                    params[i],
                    lora_request=self._resolve_mm_lora(
                        prompt,
                        None if lora_requests is None else lora_requests[i],
                    ),
                    priority=0 if priorities is None else priorities[i],
                )
                added_request_ids.append(request_id)
        except BaseException:
            # Catch BaseException (not only Exception) so that interrupting a
            # long rendering loop also aborts the partially-added batch.
            # Otherwise those requests stay in the engine and the next
            # `_run_engine()` call would collect their outputs as well.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            raise

        return added_request_ids

1782
    def _add_request(
        self,
        prompt: EngineInput,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
    ) -> str:
        """Submit a single rendered prompt to the engine.

        A fresh request ID is taken from `self.request_counter`.
        `SamplingParams` are switched to `RequestOutputKind.FINAL_ONLY` in
        place, because only finished outputs are collected offline.

        Returns:
            The value returned by `llm_engine.add_request` for this request.
        """
        if isinstance(params, SamplingParams):
            # We only care about the final output.
            params.output_kind = RequestOutputKind.FINAL_ONLY

        next_id = next(self.request_counter)
        engine = self.llm_engine
        return engine.add_request(
            str(next_id),
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
        )
1802

1803
    def _run_engine(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]:
        """Step the engine until it has no unfinished requests.

        Args:
            output_type: Expected type (or tuple of types) of every output
                returned by `llm_engine.step()`; checked with `isinstance`.
            use_tqdm: If truthy, show a progress bar. A callable is used as the
                progress-bar factory instead of `tqdm`.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[_O] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                assert isinstance(output, output_type)
                if not output.finished:
                    continue
                outputs.append(output)  # type: ignore[arg-type]
                if not use_tqdm:
                    continue

                if isinstance(output, RequestOutput):
                    # Calculate tokens only for RequestOutput
                    n = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids) * n
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs
                    )
                    # `elapsed` can be 0, e.g. tqdm reports elapsed=0 for a
                    # disabled bar (`disable=True`, which TQDM_DISABLE=1 or a
                    # custom factory may set). Only refresh the speed
                    # estimate when it is non-zero to avoid ZeroDivisionError.
                    elapsed = pbar.format_dict["elapsed"]
                    if elapsed:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s"
                        )
                    pbar.update(n)
                else:
                    pbar.update(1)

                if pbar.n == num_requests:
                    pbar.refresh()

        if use_tqdm:
            pbar.close()

        # Sort the outputs by request ID. This is necessary because some
        # requests may finish earlier than requests that were added before
        # them.
        return sorted(outputs, key=lambda x: int(x.request_id))
1857

1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
    def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest | dict
    ) -> None:
        """Initialize weight transfer for RL training.

        Args:
            request: Weight transfer initialization request with
                backend-specific info, given either as a
                `WeightTransferInitRequest` or as a dict with an
                `"init_info"` key.
        """
        if isinstance(request, dict):
            init_info_dict = request["init_info"]
        else:
            init_info_dict = request.init_info

        # Forwarded to the workers through a collective RPC.
        self.llm_engine.collective_rpc(
            "init_weight_transfer_engine", kwargs={"init_info": init_info_dict}
        )

    def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
        """Update the weights of the model.

        Args:
            request: Weight update request with backend-specific update info,
                given either as a `WeightTransferUpdateRequest` or as a dict
                with an `"update_info"` key.
        """
        if isinstance(request, dict):
            update_info_dict = request["update_info"]
        else:
            update_info_dict = request.update_info

        # Forwarded to the workers through a collective RPC.
        self.llm_engine.collective_rpc(
            "update_weights", kwargs={"update_info": update_info_dict}
        )

1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model."""
        # Cache the result to avoid repeated collective_rpc calls
        if self._cached_repr is None:
            results = self.llm_engine.collective_rpc("get_model_inspection")
            # In distributed settings, we get results from all workers
            # Just return the first one (they should all be the same)
            if results:
                self._cached_repr = results[0]
            else:
                self._cached_repr = f"LLM(model={self.model_config.model!r})"
        return self._cached_repr