llm.py 77.8 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Callable, Iterable, Sequence
6
from pathlib import Path
7
from typing import TYPE_CHECKING, Any
8

9
import cloudpickle
10
import torch.nn as nn
11
from pydantic import ValidationError
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar, overload
14

15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
22
    AttentionConfig,
23
    CompilationConfig,
24
    PoolerConfig,
25
    ProfilerConfig,
26
27
28
    StructuredOutputsConfig,
    is_init_field,
)
29
from vllm.config.compilation import CompilationMode
30
from vllm.config.model import (
31
32
    ConvertOption,
    HfOverrides,
33
    ModelDType,
34
    RunnerOption,
35
    TokenizerMode,
36
)
37
38
39
40
from vllm.distributed.weight_transfer.base import (
    WeightTransferInitRequest,
    WeightTransferUpdateRequest,
)
41
from vllm.engine.arg_utils import EngineArgs
42
43
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
44
    ChatTemplateConfig,
45
    ChatTemplateContentFormatOption,
46
    load_chat_template,
47
)
48
from vllm.entrypoints.pooling.io_processor_factories import init_pooling_io_processors
49
50
from vllm.entrypoints.pooling.scoring.io_processor import (
    ScoringIOProcessor,
51
)
52
53
from vllm.entrypoints.pooling.scoring.typing import ScoreInput
from vllm.entrypoints.pooling.typing import OfflineInputsContext, OfflineOutputsContext
54
from vllm.entrypoints.utils import log_non_default_args
55
from vllm.inputs import (
56
    DataPrompt,
57
    EngineInput,
58
59
60
61
    PromptType,
    TextPrompt,
    TokensPrompt,
)
62
from vllm.logger import init_logger
63
from vllm.lora.request import LoRARequest
64
from vllm.model_executor.layers.quantization import QuantizationMethods
65
66
67
68
69
70
71
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
72
from vllm.platforms import current_platform
73
from vllm.pooling_params import PoolingParams
74
from vllm.renderers import ChatParams, merge_kwargs
75
76
77
78
79
from vllm.renderers.inputs.preprocess import (
    conversation_to_seq,
    parse_model_prompt,
    prompt_to_seq,
)
80
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
81
from vllm.tasks import PoolingTask
82
from vllm.tokenizers import TokenizerLike
yhu422's avatar
yhu422 committed
83
from vllm.usage.usage_lib import UsageContext
84
from vllm.utils.counter import Counter
85
from vllm.utils.mistral import is_mistral_tokenizer
86
from vllm.utils.tqdm_utils import maybe_tqdm
87
from vllm.v1.engine import PauseMode
88
from vllm.v1.engine.llm_engine import LLMEngine
89
from vllm.v1.sample.logits_processor import LogitsProcessor
90

91
92
93
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

94
95
logger = init_logger(__name__)

# Generic return type for helpers that run a user-supplied callable on the
# workers (see `collective_rpc` / `apply_model`); falls back to `Any`.
_R = TypeVar("_R", default=Any)

# Bound for per-request parameter objects: sampling, pooling, or none.
_P = TypeVar("_P", bound=SamplingParams | PoolingParams | None)

# Output type returned by `wait_for_completion`; when unspecified it covers
# both generation (`RequestOutput`) and pooling (`PoolingRequestOutput`).
_O = TypeVar(
    "_O",
    default=RequestOutput | PoolingRequestOutput,
    bound=RequestOutput | PoolingRequestOutput,
)

104
105

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
106
107
108
109
110
111
112
113
114
115
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
116
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
117
118
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
119
120
121
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
122
123
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
124
125
126
127
        allowed_local_media_path: Allows API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk, so it should only be enabled in trusted
            environments.
        allowed_media_domains: If set, only media URLs that belong to one of
            these domains can be used for multi-modal inputs.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16`
            instead.
        quantization: The method used to quantize the model weights. Currently,
137
            we support "awq", "gptq", and "fp8" (experimental).
138
139
140
141
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
142
143
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
144
145
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
146
        chat_template: The chat template to apply.
147
148
149
150
151
152
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
153
154
155
156
157
        kv_cache_memory_bytes: Size of the KV cache per GPU in bytes. By
            default, this is None and vLLM automatically infers the KV cache
            size from gpu_memory_utilization. Users who want to specify the
            KV cache memory size manually can set kv_cache_memory_bytes, which
            gives more fine-grained control over how much memory is used than
            gpu_memory_utilization does. Note that when kv_cache_memory_bytes
            is not None, gpu_memory_utilization is ignored.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
165
166
167
168
169
170
171
172
173
174
175
176
177
        offload_group_size: Prefetch offloading: Group every N layers
            together. Offload last `offload_num_in_group` layers of each group.
            Default is 0 (disabled).
        offload_num_in_group: Prefetch offloading: Number of layers to
            offload per group. Default is 1.
        offload_prefetch_step: Prefetch offloading: Number of layers to
            prefetch ahead. Higher values hide more latency but use more GPU
            memory. Default is 1.
        offload_params: Prefetch offloading: Set of parameter name segments
            to selectively offload. Only parameters whose names contain one of
            these segments will be offloaded (e.g., {"gate_up_proj", "down_proj"}
            for MLP weights, or {"w13_weight", "w2_weight"} for MoE expert
            weights). If None or empty, all parameters are offloaded.
178
179
180
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
181
        enable_return_routed_experts: Whether to return routed experts.
182
183
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
184
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `hf auth login` (stored in `~/.cache/huggingface/token`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
190
191
192
193
194
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
195
196
        pooler_config: Initialize non-default pooling config for the pooling model,
            e.g., `PoolerConfig(seq_pooling_type="MEAN", use_activation=False)`.
197
        compilation_config: Either an integer or a dictionary. If it is an
198
            integer, it is used as the mode of compilation optimization. If it
199
            is a dictionary, it can specify the full compilation configuration.
200
201
202
203
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
204
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
205

206
207
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """
210
211
212
213

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        chat_template: Path | str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        cpu_offload_gb: float = 0,
        offload_group_size: int = 0,
        offload_num_in_group: int = 1,
        offload_prefetch_step: int = 1,
        offload_params: set[str] | None = None,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        See the class docstring for the meaning of each argument. The
        constructor normalizes the config-like arguments (dicts, ints or
        ``None``) into config objects, builds an `EngineArgs`, creates the
        `LLMEngine` from it, and caches handles to the engine's renderer and
        input/output processors.

        Raises:
            ValueError: If ``kv_transfer_config`` is passed as a dict that
                fails validation, or if ``data_parallel_size > 1`` is
                requested for single-process usage (i.e. without the
                ``"external_launcher"`` executor backend and not on TPU).
        """
        if "swap_space" in kwargs:
            # Deprecated and ignored: drop it so EngineArgs never receives it.
            kwargs.pop("swap_space")
            import warnings

            warnings.warn(
                "The 'swap_space' parameter is deprecated and ignored. "
                "It will be removed in a future version.",
                DeprecationWarning,
                stacklevel=2,
            )

        # Stats logging is disabled by default unless the caller set it.
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if "kv_transfer_config" in kwargs and isinstance(
            kwargs["kv_transfer_config"], dict
        ):
            from vllm.config.kv_transfer import KVTransferConfig

            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict,
                    e,
                )
                # Surface the pydantic validation failure as a ValueError,
                # keeping the original exception as the cause.
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance.

            ``None`` yields ``cls()``; a dict is filtered down to the keys for
            which ``is_init_field(cls, key)`` is true and passed to ``cls``;
            any other value is returned unchanged.
            """
            if value is None:
                return cls()
            if isinstance(value, dict):
                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
            return value

        if isinstance(compilation_config, int):
            # A bare integer only selects the compilation mode.
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject (not just warn about) single-process data parallelism: it is
        # unsupported and may hang. It is allowed only when an external
        # launcher drives the processes or when running on TPU.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            cpu_offload_gb=cpu_offload_gb,
            offload_group_size=offload_group_size,
            offload_num_in_group=offload_num_in_group,
            offload_prefetch_step=offload_prefetch_step,
            offload_params=offload_params or set(),
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.model_config = self.llm_engine.model_config
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        self.supported_tasks = supported_tasks

        self.pooling_task = self.model_config.get_pooling_task(supported_tasks)
        if self.pooling_task is not None:
            logger.info("Supported pooling task: %s", self.pooling_task)

        self.runner_type = self.model_config.runner_type
        self.renderer = self.llm_engine.renderer
        self.chat_template = load_chat_template(chat_template)
        self.io_processor = self.llm_engine.io_processor
        self.input_processor = self.llm_engine.input_processor
        self.chat_template_config = ChatTemplateConfig(chat_template=self.chat_template)
        self.pooling_io_processors = init_pooling_io_processors(
            supported_tasks=supported_tasks,
            model_config=self.model_config,
            renderer=self.renderer,
            chat_template_config=self.chat_template_config,
        )

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

406
407
408
409
410
    @classmethod
    def from_engine_args(cls, engine_args: EngineArgs) -> "LLM":
        """Build an `LLM` from an existing `EngineArgs` object.

        Every attribute of ``engine_args`` is forwarded to the constructor as
        a keyword argument.
        """
        init_kwargs = vars(engine_args)
        return cls(**init_kwargs)

411
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer held by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
413

414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
    def get_world_size(self, include_dp: bool = True) -> int:
        """Return the worker count described by the engine's parallel config.

        Args:
            include_dp: When true (the default), data-parallel replicas are
                counted as well (TP * PP * DP); when false, only TP * PP is
                reported.

        Returns:
            ``parallel_config.world_size_across_dp`` if ``include_dp`` is
            truthy, otherwise ``parallel_config.world_size``.
        """
        config = self.llm_engine.vllm_config.parallel_config
        return config.world_size_across_dp if include_dp else config.world_size

431
    def reset_mm_cache(self) -> None:
        """Clear the multi-modal caches of the renderer and of the engine.

        The renderer's cache is cleared first, then the engine is asked to
        reset its own multi-modal cache.
        """
        self.renderer.clear_mm_cache()
        engine = self.llm_engine
        engine.reset_mm_cache()

435
    def get_default_sampling_params(self) -> SamplingParams:
        """Return a fresh `SamplingParams` built from the model's defaults.

        The model config's sampling-parameter overrides are fetched on first
        use and cached on ``self.default_sampling_params``. When there are no
        overrides, a plain ``SamplingParams()`` is returned.
        """
        overrides = self.default_sampling_params
        if overrides is None:
            overrides = self.model_config.get_diff_sampling_param()
            self.default_sampling_params = overrides
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

442
443
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """Generate completions for the input prompts.

        Prompts are batched automatically, taking the memory constraint into
        account. For the best performance, put all of your prompts into a
        single list and pass it to this method in one call.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, the default sampling parameters are used.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompts.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.
                If provided, must be a list of integers matching the length
                of `prompts`, where each priority value corresponds to the
                prompt at the same index.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `RequestOutput` objects containing the generated
            completions, in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not ``"generate"``.
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        params = (
            self.get_default_sampling_params()
            if sampling_params is None
            else sampling_params
        )

        return self._run_completion(
            prompts=prompts,
            params=params,
            output_type=RequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
503

504
505
506
507
    def enqueue(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Enqueue prompts for generation without waiting for completion.

        This method adds requests to the engine queue but does not start
        processing them. Use wait_for_completion() to process the queued
        requests and get results.

        Args:
            prompts: The prompts to the LLM. See generate() for details.
            sampling_params: The sampling parameters for text generation. If
                None, the model's default sampling parameters are used. See
                generate() for how single values and sequences are paired
                with the prompts.
            lora_request: LoRA request to use for generation, if any.
            priority: The priority of the requests, if any.
            use_tqdm: If True, shows a tqdm progress bar while adding requests.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of request IDs for the enqueued requests.

        Raises:
            ValueError: If the model's runner type is not ``"generate"``.
        """
        if self.model_config.runner_type != "generate":
            # Same actionable hint as generate(), so both entry points
            # report the misconfiguration identically.
            raise ValueError(
                "LLM.enqueue() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        return self._add_completion_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
            tokenization_kwargs=tokenization_kwargs,
        )

546
    @overload
    def wait_for_completion(
        self,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[RequestOutput | PoolingRequestOutput]: ...

    @overload
    def wait_for_completion(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]: ...

    def wait_for_completion(
        self,
        output_type: type[Any] | tuple[type[Any], ...] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[Any]:
        """Wait for all enqueued requests to complete and return results.

        This method processes all requests currently in the engine queue
        and returns their outputs. Use after enqueue() to get results.

        Args:
            output_type: The expected output type, or a tuple of accepted
                output types. If None (the default), both `RequestOutput`
                and `PoolingRequestOutput` are accepted.
            use_tqdm: If True, shows a tqdm progress bar. The value (which
                may also be a tqdm-factory callable) is forwarded unchanged
                to the engine loop.

        Returns:
            A list of output objects for all completed requests.
        """
        if output_type is None:
            output_type = (RequestOutput, PoolingRequestOutput)

        return self._run_engine(output_type, use_tqdm=use_tqdm)
583

Cyrus Leung's avatar
Cyrus Leung committed
584
    def _resolve_mm_lora(
        self,
        prompt: EngineInput,
        lora_request: LoRARequest | None,
    ) -> LoRARequest | None:
        """Choose the LoRA to apply to a single processed prompt.

        An explicitly provided ``lora_request`` always takes precedence. If
        there is none, a ``"multimodal"`` prompt whose placeholder modalities
        match exactly one entry of the LoRA config's ``default_mm_loras``
        gets a `LoRARequest` for that modality's default LoRA.

        Args:
            prompt: The processed engine input. Only prompts whose ``"type"``
                is ``"multimodal"`` are considered for default modality LoRAs.
            lora_request: The LoRA explicitly requested by the caller, if any.

        Returns:
            The LoRA request to use for this prompt, or None.
        """
        # Only multi-modal prompts can pick up a modality-specific LoRA.
        if prompt["type"] != "multimodal":
            return lora_request

        lora_config = self.llm_engine.vllm_config.lora_config
        mm_loras = (
            lora_config.default_mm_loras if lora_config is not None else None
        )
        if not mm_loras:
            return lora_request

        matched = set(prompt["mm_placeholders"].keys()).intersection(
            mm_loras.keys()
        )

        # Zero matches: nothing to add. Several matches: we can only attach
        # one LoRA per request, so none of the defaults is applied.
        if len(matched) != 1:
            if matched:
                # TODO: Would be nice to be able to have multiple loras per prompt
                logger.warning(
                    "Multiple modality specific loras were registered and would be "
                    "used by a single prompt consuming several modalities; "
                    "currently we only support one lora per request; as such, "
                    "lora(s) registered with modalities: %s will be skipped",
                    matched,
                )
            return lora_request

        # Exactly one modality matched. Its default LoRA gets a stable ID:
        # 1 + the modality's position among the sorted modality names.
        (modality,) = matched
        lora_path = mm_loras[modality]
        lora_id = sorted(mm_loras).index(modality) + 1

        if not lora_request:
            return LoRARequest(
                modality,
                lora_id,
                lora_path,
            )

        # The explicit request wins; warn only when its ID differs from the
        # ID the default modality LoRA would have been given.
        if lora_request.lora_int_id != lora_id:
            logger.warning(
                "A modality with a registered lora and a lora_request "
                "with a different ID were provided; falling back to the "
                "lora_request as we only apply one LoRARequest per prompt"
            )
        return lora_request

636
637
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """
        Execute an RPC call on all workers and collect their results.

        Args:
            method: The name of a worker method to execute, or a callable
                that is serialized and sent to every worker to execute.

                A callable must accept an extra `self` argument (the worker
                object) in addition to the arguments passed in `args` and
                `kwargs`.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
667
668

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.

        Args:
            func: Callable that receives the worker's model module.

        Returns:
            A list with the value returned by `func` on each worker.

        !!! warning
            To reduce the overhead of data transfer, avoid returning large
            arrays or tensors from this method. If you must return them,
            make sure you move them to CPU first to avoid taking up additional
            VRAM!
        """
        return self.llm_engine.apply_model(func)

    def beam_search(
        self,
        prompts: list[TokensPrompt | TextPrompt],
        params: BeamSearchParams,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        use_tqdm: bool = False,
        concurrency_limit: int | None = None,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is a `TextPrompt` or a
                `TokensPrompt`; embedding prompts are not supported.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            use_tqdm: Whether to use tqdm to display the progress bar.
                Ignored (with a warning) when `concurrency_limit` is set.
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.

        Returns:
            One `BeamSearchOutput` per prompt, in input order, each holding
            at most `beam_width` sequences sorted by `sort_beams_key`
            (best first). An empty `prompts` list yields an empty list.

        Raises:
            ValueError: If `concurrency_limit` is not a positive integer.
            NotImplementedError: If a prompt is an embedding prompt.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        if concurrency_limit is not None and concurrency_limit <= 0:
            # A zero step would crash `range` below and a negative one would
            # silently skip generation entirely.
            raise ValueError(
                "concurrency_limit must be a positive integer, "
                f"got {concurrency_limit}"
            )

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.renderer.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(eos_token_id, length_penalty)

        engine_inputs = self._preprocess_cmpl(prompts)
        lora_requests = self._lora_request_to_seq(lora_request, len(engine_inputs))

        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar."
            )
            use_tqdm = False

        if concurrency_limit is None:
            # Run everything as a single batch. Clamp to 1 so an empty prompt
            # list does not produce a zero `range` step.
            concurrency_limit = max(len(engine_inputs), 1)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        sampling_params = SamplingParams(
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
        )
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, engine_inputs):
            if prompt["type"] == "embeds":
                raise NotImplementedError(
                    "Embedding prompt not supported for beam search"
                )

            instances.append(
                BeamSearchInstance(
                    prompt,
                    lora_request=lora_req,
                    logprobs=None,
                ),
            )

        for prompt_start in range(0, len(instances), concurrency_limit):
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress."
                )

            for _ in token_iter:
                # Flatten the live beams of every instance in the batch into a
                # single request list (chain avoids quadratic list summing).
                all_beams: list[BeamSearchSequence] = list(
                    itertools.chain.from_iterable(
                        instance.beams for instance in instances_batch
                    )
                )
                if not all_beams:
                    break

                # Instance k owns all_beams[pos[k]:pos[k + 1]].
                pos = [
                    0,
                    *itertools.accumulate(
                        len(instance.beams) for instance in instances_batch
                    ),
                ]
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:])
                )

                # only runs for one step
                # we don't need to use tqdm here
                output = self._render_and_run_requests(
                    prompts=(beam.get_prompt() for beam in all_beams),
                    params=self._params_to_seq(sampling_params, len(all_beams)),
                    output_type=RequestOutput,
                    lora_requests=[beam.lora_request for beam in all_beams],
                    use_tqdm=False,
                )

                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    current_beam.orig_prompt,
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs + [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                )

                                if token_id == eos_token_id and not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
                    instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after max_tokens compete with finished ones.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)

            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def _preprocess_cmpl(
        self,
        prompts: Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[EngineInput]:
        """
        Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
        a format that can be passed to `_add_request`.

        Refer to [LLM.generate][] for a complete description of the arguments.

        Args:
            prompts: The raw prompts, each parsed with `parse_model_prompt`.
            tokenization_kwargs: Overrides applied on top of the renderer's
                default completion tokenization parameters.

        Returns:
            A list of `EngineInput` objects ready to be passed into LLMEngine.
        """
        renderer = self.renderer
        model_config = self.model_config

        parsed_prompts = [
            parse_model_prompt(model_config, prompt) for prompt in prompts
        ]
        tok_params = renderer.default_cmpl_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

        return renderer.render_cmpl(parsed_prompts, tok_params)

    def _preprocess_cmpl_one(
        self,
        prompt: PromptType,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> EngineInput:
        """
        Single-prompt convenience wrapper around `_preprocess_cmpl`.

        Raises:
            ValueError: If preprocessing does not yield exactly one input
                (from the single-element tuple unpacking).
        """
        (engine_input,) = self._preprocess_cmpl([prompt], tokenization_kwargs)
        return engine_input

    def _preprocess_chat(
        self,
        conversations: Sequence[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> Sequence[EngineInput]:
        """
        Convert a list of conversations into prompts so that they can then
        be used as input for other LLM APIs.

        Refer to [LLM.chat][] for a complete description of the arguments.
        `add_generation_prompt`, `continue_final_message` and `tools` are
        combined with `chat_template_kwargs` (via `merge_kwargs`) before
        rendering, and `mm_processor_kwargs` is attached to every rendered
        prompt as an extra.

        Returns:
            A list of `EngineInput` objects ready to be passed into LLMEngine.
        """
        renderer = self.renderer

        chat_params = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=merge_kwargs(
                chat_template_kwargs,
                dict(
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                    # Tokenize during templating only for Mistral tokenizers.
                    tokenize=is_mistral_tokenizer(renderer.tokenizer),
                ),
            ),
        )
        tok_params = renderer.default_chat_tok_params.with_kwargs(
            **(tokenization_kwargs or {})
        )

        # The first element returned by `render_chat` is not needed here.
        _, engine_inputs = renderer.render_chat(
            conversations,
            chat_params,
            tok_params,
            prompt_extras={"mm_processor_kwargs": mm_processor_kwargs},
        )

        return engine_inputs

    def _preprocess_chat_one(
        self,
        conversation: list[ChatCompletionMessageParam],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> EngineInput:
        """
        Single-conversation convenience wrapper around `_preprocess_chat`.

        All arguments are forwarded unchanged; see `_preprocess_chat`.

        Raises:
            ValueError: If rendering does not yield exactly one input
                (from the single-element tuple unpacking).
        """
        (engine_input,) = self._preprocess_chat(
            [conversation],
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        return engine_input

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the [generate][vllm.LLM.generate] method to generate
        the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A sequence of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional list of tool definitions (as dicts) passed
                through to chat rendering.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            mm_processor_kwargs: Overrides for `processor.__call__`.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.

        Raises:
            ValueError: If the model is not run with the "generate" runner.
        """
        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.chat() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        if sampling_params is None:
            sampling_params = self.get_default_sampling_params()

        return self._run_chat(
            messages=messages,
            params=sampling_params,
            output_type=RequestOutput,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
                A dict with a "data" key is routed through the installed
                IOProcessor plugin instead.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the task is rejected by `_verify_pooling_task`, if a
                "data" prompt is given without an IOProcessor plugin or with
                `data=None`, or if a pooling param already carries a task
                different from `pooling_task`.
        """
        self._verify_pooling_task(pooling_task)

        if isinstance(prompts, dict) and "data" in prompts:
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            prompt_data = prompts.get("data")
            if prompt_data is None:
                raise ValueError(
                    "The 'data' field of the prompt is expected to contain "
                    "the prompt data and it cannot be None. "
                    "Refer to the documentation of the IOProcessor "
                    "in use for more details."
                )
            validated_prompt = self.io_processor.parse_data(prompt_data)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)
            prompts_seq = prompt_to_seq(prompts)

            params_seq: Sequence[PoolingParams] = [
                self.io_processor.merge_pooling_params(param)
                for param in self._params_to_seq(
                    pooling_params,
                    len(prompts_seq),
                )
            ]
            for p in params_seq:
                if p.task is None:
                    p.task = "plugin"

            outputs = self._run_completion(
                prompts=prompts_seq,
                params=params_seq,
                output_type=PoolingRequestOutput,
                use_tqdm=use_tqdm,
                lora_request=lora_request,
                tokenization_kwargs=tokenization_kwargs,
            )

            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(outputs)

            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
                    prompt_token_ids=[],
                    finished=True,
                )
            ]
        else:
            if pooling_params is None:
                # Use default pooling params.
                pooling_params = PoolingParams()

            prompts_seq = prompt_to_seq(prompts)
            params_seq = self._params_to_seq(pooling_params, len(prompts_seq))

            for param in params_seq:
                if param.task is None:
                    param.task = pooling_task
                elif param.task != pooling_task:
                    msg = (
                        f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                    )
                    raise ValueError(msg)

            if pooling_task in self.pooling_io_processors:
                io_processor = self.pooling_io_processors[pooling_task]
                processor_inputs = io_processor.pre_process_offline(
                    ctx=OfflineInputsContext(
                        prompts=prompts_seq, tokenization_kwargs=tokenization_kwargs
                    )
                )
                seq_lora_requests = self._lora_request_to_seq(
                    lora_request, len(prompts_seq)
                )
                # Size by the normalized sequence, like the LoRA list above:
                # `len(prompts)` would count characters for a single str prompt.
                seq_priority = self._priority_to_seq(None, len(prompts_seq))

                self._render_and_add_requests(
                    prompts=processor_inputs,
                    params=params_seq,
                    lora_requests=seq_lora_requests,
                    priorities=seq_priority,
                )

                outputs = self._run_engine(
                    use_tqdm=use_tqdm, output_type=PoolingRequestOutput
                )
                outputs = io_processor.post_process_offline(
                    ctx=OfflineOutputsContext(outputs=outputs)
                )
            else:
                outputs = self._run_completion(
                    prompts=prompts_seq,
                    params=params_seq,
                    output_type=PoolingRequestOutput,
                    use_tqdm=use_tqdm,
                    lora_request=lora_request,
                    tokenization_kwargs=tokenization_kwargs,
                )
        return outputs

    def _verify_pooling_task(self, pooling_task: PoolingTask | None):
        """
        Check that `pooling_task` can be served by this model.

        Checks run in order: the runner must be "pooling"; a task must be
        given; embed/classify-family tasks must be in `self.supported_tasks`;
        and any task other than "plugin" that differs from the default
        `self.pooling_task` must be supported. The last case is still
        accepted, but logs a one-time deprecation warning.

        Raises:
            ValueError: If any of the checks above fails.
        """
        if self.runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        if (
            pooling_task in ("embed", "token_embed")
            and pooling_task not in self.supported_tasks
        ):
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        if (
            pooling_task in ("classify", "token_classify")
            and pooling_task not in self.supported_tasks
        ):
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        # plugin task uses io_processor.parse_request to verify inputs
        if pooling_task != "plugin" and pooling_task != self.pooling_task:
            if pooling_task not in self.supported_tasks:
                raise ValueError(
                    f"Unsupported task: {pooling_task!r}. "
                    f"Supported tasks: {self.supported_tasks}"
                )
            else:
                logger.warning_once(
                    "Pooling multitask support is deprecated and will "
                    "be removed in v0.20. When the default pooling task is "
                    "not what you want, you need to manually specify it "
                    'via PoolerConfig(task="%s").',
                    pooling_task,
                )

    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.
        """
        # Delegates to `encode` with the "embed" task and re-wraps results.
        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.
        """
        # Delegates to `encode` with the "classify" task and re-wraps results.
        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        # Rewards are computed as per-token classification.
        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_classify",
            tokenization_kwargs=tokenization_kwargs,
        )

    def score(
        self,
        data_1: ScoreInput | list[ScoreInput],
        data_2: ScoreInput | list[ScoreInput],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
        `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`. In the `1 -> N`
        case the `data_1` input is replicated `N` times to pair with each of
        the `data_2` inputs.

        The input pairs are used to build the prompts for the model, which
        are batched automatically under the memory constraint. For the best
        performance, put all of your inputs into a single list and pass it
        to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, a
                default `PoolingParams()` is used. Its `task` is filled in
                with the scoring task when unset.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: The chat template to use for the scoring. If None,
                the model's default chat template is used.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not run as a pooling model, if a
                cross-encoder model does not have `num_labels == 1`, if the
                model has no scoring IO processor, or if `pooling_params.task`
                is set to a task other than the scoring task.
        """
        if self.runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        score_type = self.model_config.score_type

        # Cross-encoder scoring is only enabled for single-label heads.
        if score_type == "cross-encoder":
            if getattr(self.model_config.hf_config, "num_labels", 0) != 1:
                raise ValueError("Scoring API is only enabled for num_labels == 1.")

        processors = self.pooling_io_processors
        if score_type is None or score_type not in processors:
            raise ValueError("This model does not support the Scoring API.")
        io_processor = processors[score_type]
        assert isinstance(io_processor, ScoringIOProcessor)

        pooling_task = io_processor.pooling_task
        scoring_data = io_processor.valid_inputs(data_1, data_2)
        # Number of `data_1` entries; handed to both the pre- and
        # post-processing contexts.
        offset = len(scoring_data.data_1)

        ctx = OfflineInputsContext(
            prompts=scoring_data,
            pooling_params=pooling_params,
            tokenization_kwargs=tokenization_kwargs,
            chat_template=chat_template,
            offset=offset,
        )
        processor_inputs = io_processor.pre_process_offline(ctx)
        num_inputs = len(processor_inputs)

        seq_lora_requests = self._lora_request_to_seq(lora_request, num_inputs)

        if ctx.pooling_params is None:
            ctx.pooling_params = PoolingParams()
        params_seq = self._params_to_seq(ctx.pooling_params, num_inputs)

        # Stamp the scoring task onto each params object, refusing to
        # silently replace a different task chosen by the caller.
        for param in params_seq:
            if param.task is None:
                param.task = pooling_task
                continue
            if param.task != pooling_task:
                msg = f"You cannot overwrite {param.task=!r} with {pooling_task=!r}!"
                raise ValueError(msg)

        self._render_and_add_requests(
            prompts=processor_inputs,
            params=params_seq,
            lora_requests=seq_lora_requests,
            priorities=self._priority_to_seq(None, num_inputs),
        )

        raw_outputs = self._run_engine(
            use_tqdm=use_tqdm, output_type=PoolingRequestOutput
        )
        scored = io_processor.post_process_offline(
            ctx=OfflineOutputsContext(outputs=raw_outputs, offset=offset),
        )
        return [ScoringRequestOutput.from_base(item) for item in scored]
1494

1495
1496
1497
1498
1499
1500
1501
1502
1503
    def start_profile(self, profile_prefix: str | None = None) -> None:
        """Start profiling with optional custom trace prefix.

        Args:
            profile_prefix: Optional prefix for the trace file names. If provided,
                           trace files will be named as "<prefix>_dp<X>_pp<Y>_tp<Z>".
                           If not provided, default naming will be used.
        """
        self.llm_engine.start_profile(profile_prefix)
1504
1505
1506
1507

    def stop_profile(self) -> None:
        """Forward a stop-profiling request to the engine."""
        engine = self.llm_engine
        engine.stop_profile()

1508
1509
1510
1511
1512
1513
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache.

        Args:
            reset_running_requests: Passed positionally to the engine's
                `reset_prefix_cache`.
            reset_connector: Passed positionally to the engine's
                `reset_prefix_cache`.

        Returns:
            The value returned by the engine's `reset_prefix_cache`.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(reset_running_requests, reset_connector)
1514

1515
    def sleep(self, level: int = 1, mode: PauseMode = "abort"):
        """
        Put the engine to sleep so that it processes no requests until
        `wake_up` is called. The caller must make sure no requests are being
        processed during the sleep period.

        Args:
            level: How deeply the engine sleeps.
                - Level 0: Pause scheduling but keep accepting requests;
                           they are queued but not processed.
                - Level 1: Offload model weights to CPU and discard the KV
                           cache, whose contents are forgotten. Suited to
                           waking up later to run the same model again;
                           ensure there is enough CPU memory to hold the
                           model weights.
                - Level 2: Discard all GPU memory (weights + KV cache).
                           Suited to waking up with a different or updated
                           model, when the previous weights are not needed.
                           It reduces CPU memory pressure.
            mode: How to handle any existing requests: "abort", "wait",
                or "keep".
        """
        engine = self.llm_engine
        engine.sleep(level=level, mode=mode)
1539

1540
    def wake_up(self, tags: list[str] | None = None):
1541
        """
1542
1543
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1544

1545
        Args:
1546
1547
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1548
1549
1550
1551
                `("weights", "kv_cache", "scheduling")`. If None, all memory
                is reallocated. wake_up should be called with all tags
                (or None) before the engine is used again.
                Use tags=["scheduling"] to resume from level 0 sleep.
1552
1553
        """
        self.llm_engine.wake_up(tags)
1554

1555
1556
1557
1558
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects, as reported by the engine, capturing
            the current state of all aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1567
    def _params_to_seq(
1568
        self,
1569
        params: _P | Sequence[_P],
1570
        num_requests: int,
1571
    ) -> Sequence[_P]:
1572
1573
1574
1575
        if isinstance(params, Sequence):
            if len(params) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({params}) "
1576
                    f"and params ({len(params)}) must be the same."
1577
1578
                )

1579
            return params
1580

1581
1582
1583
1584
1585
1586
1587
        return [params] * num_requests

    def _lora_request_to_seq(
        self,
        lora_request: LoRARequest | None | Sequence[LoRARequest | None],
        num_requests: int,
    ) -> Sequence[LoRARequest | None]:
1588
1589
1590
1591
1592
1593
1594
        if isinstance(lora_request, Sequence):
            if len(lora_request) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and lora_request ({len(lora_request)}) must be the same."
                )

1595
1596
1597
            return lora_request

        return [lora_request] * num_requests
1598

1599
1600
1601
1602
1603
    def _priority_to_seq(
        self,
        priority: list[int] | None,
        num_requests: int,
    ) -> Sequence[int]:
1604
1605
1606
1607
1608
1609
1610
        if priority is not None:
            if len(priority) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and priority ({len(priority)}) must be the same."
                )

1611
1612
1613
1614
            return priority

        return [0] * num_requests

1615
    def _add_completion_requests(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[str]:
        """Render completion prompts and add them to the engine.

        Prompts are rendered lazily with `_preprocess_cmpl_one` while they are
        being added. The rendering step is wrapped with `maybe_tqdm` and the
        description "Rendering prompts".

        Args:
            prompts: A single prompt or a sequence of prompts.
            params: One params object for all prompts, or one per prompt.
            use_tqdm: Controls the rendering progress bar.
            lora_request: One LoRA request for all prompts, or one per prompt.
            priority: Optional per-prompt priorities; defaults to 0 for all.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            The IDs of the added requests, in prompt order.

        Raises:
            ValueError: If `params`, `lora_request` or `priority` is a
                sequence whose length does not match the number of prompts.
        """
        seq_prompts = prompt_to_seq(prompts)
        num_requests = len(seq_prompts)

        seq_params = self._params_to_seq(params, num_requests)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_requests)
        # Use the normalized prompt count: `prompts` may be a single
        # `PromptType` (e.g. a string or dict prompt), whose own `len()` is
        # not the number of requests.
        seq_priority = self._priority_to_seq(priority, num_requests)

        return self._render_and_add_requests(
            prompts=(
                self._preprocess_cmpl_one(prompt, tokenization_kwargs)
                for prompt in maybe_tqdm(
                    seq_prompts,
                    use_tqdm=use_tqdm,
                    desc="Rendering prompts",
                )
            ),
            params=seq_params,
            lora_requests=seq_lora_requests,
            priorities=seq_priority,
        )

1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
    def _run_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ):
        """Add completion requests, then run the engine until they finish.

        Returns:
            The finished outputs (checked against `output_type`), sorted by
            request ID by `_run_engine`.
        """
        self._add_completion_requests(
            prompts,
            params,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
            lora_request=lora_request,
            use_tqdm=use_tqdm,
        )
        return self._run_engine(output_type, use_tqdm=use_tqdm)

1669
1670
1671
1672
1673
1674
1675
    def _run_chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | Sequence[list[ChatCompletionMessageParam]],
        params: SamplingParams
        | PoolingParams
        | Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ):
        """Render chat conversations and run them through the engine.

        Each conversation is rendered with `_preprocess_chat_one` inside a
        generator, so rendering happens as `_render_and_run_requests` consumes
        it. The chat-template options are forwarded to every call.

        Returns:
            The finished outputs (checked against `output_type`), sorted by
            request ID.
        """
        seq_convs = conversation_to_seq(messages)
        num_convs = len(seq_convs)
        seq_params = self._params_to_seq(params, num_convs)
        seq_lora_requests = self._lora_request_to_seq(lora_request, num_convs)

        conversations = maybe_tqdm(
            seq_convs,
            use_tqdm=use_tqdm,
            desc="Rendering conversations",
        )
        rendered = (
            self._preprocess_chat_one(
                conversation,
                chat_template=chat_template,
                chat_template_content_format=chat_template_content_format,
                chat_template_kwargs=chat_template_kwargs,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                tokenization_kwargs=tokenization_kwargs,
                mm_processor_kwargs=mm_processor_kwargs,
            )
            for conversation in conversations
        )

        return self._render_and_run_requests(
            rendered,
            seq_params,
            output_type,
            lora_requests=seq_lora_requests,
            use_tqdm=use_tqdm,
        )

1718
1719
    def _render_and_run_requests(
        self,
        prompts: Iterable[EngineInput],
        params: Sequence[SamplingParams | PoolingParams],
        output_type: type[_O],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ):
        """Add every prompt as an engine request, then run the engine.

        Logs a one-time warning when `prompts` is an already-materialized
        list or tuple, since a generator is the preferred input.

        Returns:
            The finished outputs (checked against `output_type`), sorted by
            request ID.
        """
        already_rendered = isinstance(prompts, (list, tuple))
        if already_rendered:
            logger.warning_once(
                "Rendering all prompts before adding them to the engine "
                "is less efficient than performing both on the same prompt "
                "before processing the next prompt. You should instead pass "
                "a generator that renders one prompt per iteration, as that allows "
                "engine execution to begin for the first prompt while processing "
                "the next prompt."
            )

        self._render_and_add_requests(
            prompts,
            params,
            lora_requests=lora_requests,
            priorities=priorities,
        )
        return self._run_engine(output_type, use_tqdm=use_tqdm)
1746

1747
    def _render_and_add_requests(
        self,
        prompts: Iterable[EngineInput],
        params: Sequence[SamplingParams | PoolingParams],
        *,
        lora_requests: Sequence[LoRARequest | None] | None = None,
        priorities: Sequence[int] | None = None,
    ) -> list[str]:
        """Add one engine request per prompt, aborting all of them on failure.

        `prompts` is consumed lazily. The i-th prompt is paired with
        `params[i]`, `lora_requests[i]` (None when `lora_requests` is None)
        and `priorities[i]` (0 when `priorities` is None). The LoRA request
        is passed through `_resolve_mm_lora` first.

        If anything raises while iterating, including an error raised by the
        `prompts` iterable itself, the requests already added in this call
        are aborted and the exception is re-raised.

        Returns:
            The request IDs, in the order the prompts were consumed.
        """
        request_ids: list[str] = []

        try:
            for idx, prompt in enumerate(prompts):
                request_params = params[idx]
                lora = None if lora_requests is None else lora_requests[idx]
                resolved_lora = self._resolve_mm_lora(prompt, lora)
                prio = 0 if priorities is None else priorities[idx]

                request_ids.append(
                    self._add_request(
                        prompt,
                        request_params,
                        lora_request=resolved_lora,
                        priority=prio,
                    )
                )
        except Exception:
            if request_ids:
                self.llm_engine.abort_request(request_ids, internal=True)
            raise

        return request_ids

1776
    def _add_request(
        self,
        prompt: EngineInput,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        priority: int = 0,
    ) -> str:
        """Submit a single request to the engine under a fresh request ID.

        A `SamplingParams` argument is mutated in place: its `output_kind` is
        set to `RequestOutputKind.FINAL_ONLY`.

        Returns:
            The value returned by `llm_engine.add_request`.
        """
        is_sampling = isinstance(params, SamplingParams)
        if is_sampling:
            # We only care about the final output.
            params.output_kind = RequestOutputKind.FINAL_ONLY

        new_request_id = str(next(self.request_counter))

        engine = self.llm_engine
        return engine.add_request(
            new_request_id,
            prompt,
            params,
            lora_request=lora_request,
            priority=priority,
        )
1796

1797
    def _run_engine(
        self,
        output_type: type[_O] | tuple[type[_O], ...],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
    ) -> list[_O]:
        """Step the engine until it has no unfinished requests left.

        Args:
            output_type: Type (or tuple of types) that every engine output is
                asserted to be an instance of.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.

        Returns:
            The finished outputs, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[_O] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                assert isinstance(output, output_type)
                if not output.finished:
                    continue
                outputs.append(output)  # type: ignore[arg-type]
                if not use_tqdm:
                    continue

                if isinstance(output, RequestOutput):
                    # Calculate tokens only for RequestOutput
                    n = len(output.outputs)
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids) * n
                    total_out_toks += sum(len(stp.token_ids) for stp in output.outputs)
                    # A disabled bar (e.g. `tqdm(..., disable=True)`) reports
                    # an elapsed time of 0; skip the speed estimate instead
                    # of raising ZeroDivisionError.
                    elapsed = pbar.format_dict["elapsed"]
                    if elapsed:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s"
                        )
                    pbar.update(n)
                else:
                    pbar.update(1)

                if pbar.n == num_requests:
                    pbar.refresh()

        if use_tqdm:
            pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were added before them.
        return sorted(outputs, key=lambda x: int(x.request_id))
1851

1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
    def init_weight_transfer_engine(
        self, request: WeightTransferInitRequest | dict
    ) -> None:
        """
        Initialize weight transfer for RL training.

        Args:
            request: A `WeightTransferInitRequest`, or a dict with an
                `"init_info"` key. Only its init info is used; it is passed
                as the `init_info` keyword of the
                `"init_weight_transfer_engine"` collective RPC.
        """
        if isinstance(request, dict):
            init_info = request["init_info"]
        else:
            init_info = request.init_info

        self.llm_engine.collective_rpc(
            "init_weight_transfer_engine",
            kwargs={"init_info": init_info},
        )

    def update_weights(self, request: WeightTransferUpdateRequest | dict) -> None:
        """
        Update the weights of the model.

        Args:
            request: A `WeightTransferUpdateRequest`, or a dict with an
                `"update_info"` key. Only its update info is used; it is
                passed as the `update_info` keyword of the `"update_weights"`
                collective RPC.
        """
        if isinstance(request, dict):
            update_info = request["update_info"]
        else:
            update_info = request.update_info

        self.llm_engine.collective_rpc(
            "update_weights",
            kwargs={"update_info": update_info},
        )

1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model.

        The text is fetched once with the `"get_model_inspection"` collective
        RPC and cached in `self._cached_repr`. If the RPC returns no results,
        an `LLM(model=...)` placeholder is cached instead.
        """
        cached = self._cached_repr
        if cached is not None:
            return cached

        # Workers in a distributed setup each reply; they should agree, so
        # the first reply is used.
        results = self.llm_engine.collective_rpc("get_model_inspection")
        cached = results[0] if results else f"LLM(model={self.model_config.model!r})"
        self._cached_repr = cached
        return cached