llm.py 78.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Callable, Sequence
7
from typing import TYPE_CHECKING, Any, TypeAlias, cast
8

9
import cloudpickle
10
import torch.nn as nn
11
from pydantic import ValidationError
12
from tqdm.auto import tqdm
13
from typing_extensions import TypeVar
14

15
16
17
18
19
20
21
from vllm.beam_search import (
    BeamSearchInstance,
    BeamSearchOutput,
    BeamSearchSequence,
    create_sort_beams_key_function,
)
from vllm.config import (
22
    AttentionConfig,
23
    CompilationConfig,
24
    PoolerConfig,
25
    ProfilerConfig,
26
27
28
    StructuredOutputsConfig,
    is_init_field,
)
29
from vllm.config.compilation import CompilationMode
30
from vllm.config.model import (
31
32
    ConvertOption,
    HfOverrides,
33
    ModelDType,
34
    RunnerOption,
35
    TokenizerMode,
36
)
37
from vllm.engine.arg_utils import EngineArgs
38
39
40
41
from vllm.entrypoints.chat_utils import (
    ChatCompletionMessageParam,
    ChatTemplateContentFormatOption,
)
42
from vllm.entrypoints.pooling.score.utils import (
43
    ScoreData,
44
45
46
    ScoreMultiModalParam,
    _cosine_similarity,
    compress_token_type_ids,
47
    compute_maxsim_score,
48
    get_score_prompt,
49
    validate_score_input,
50
)
51
from vllm.entrypoints.utils import log_non_default_args
52
53
from vllm.inputs import (
    DataPrompt,
54
55
    EmbedsPrompt,
    ExplicitEncoderDecoderPrompt,
56
57
58
59
60
    PromptType,
    SingletonPrompt,
    TextPrompt,
    TokensPrompt,
)
61
from vllm.inputs.parse import get_prompt_components, is_explicit_encoder_decoder_prompt
62
from vllm.logger import init_logger
63
from vllm.lora.request import LoRARequest
64
from vllm.model_executor.layers.quantization import QuantizationMethods
65
66
67
68
69
70
71
from vllm.outputs import (
    ClassificationRequestOutput,
    EmbeddingRequestOutput,
    PoolingRequestOutput,
    RequestOutput,
    ScoringRequestOutput,
)
72
from vllm.platforms import current_platform
73
from vllm.pooling_params import PoolingParams
74
from vllm.renderers import ChatParams, TokenizeParams, merge_kwargs
75
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
76
from vllm.tasks import PoolingTask
77
78
from vllm.tokenizers import TokenizerLike
from vllm.tokenizers.mistral import MistralTokenizer
yhu422's avatar
yhu422 committed
79
from vllm.usage.usage_lib import UsageContext
80
from vllm.utils.collection_utils import as_iter, is_list_of
81
from vllm.utils.counter import Counter
82
from vllm.v1.engine.llm_engine import LLMEngine
83
from vllm.v1.sample.logits_processor import LogitsProcessor
84

85
86
87
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

88
89
logger = init_logger(__name__)

90
91
_R = TypeVar("_R", default=Any)

92
93
94
EnginePrompt: TypeAlias = TextPrompt | TokensPrompt | EmbedsPrompt
EngineEncDecPrompt: TypeAlias = ExplicitEncoderDecoderPrompt[EnginePrompt, EnginePrompt]

95
96

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
97
98
99
100
101
102
103
104
105
106
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
107
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
108
109
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
110
111
112
        skip_tokenizer_init: If True, skip initialization of the tokenizer and
            detokenizer. Inputs are then expected to provide valid
            `prompt_token_ids` and `None` for `prompt`.
113
114
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
115
116
117
118
        allowed_local_media_path: Allows API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
119
        allowed_media_domains: If set, only media URLs that belong to these
            domains can be used for multi-modal inputs.
Woosuk Kwon's avatar
Woosuk Kwon committed
121
122
123
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
124
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
125
126
            the `dtype` attribute of the Transformers model's config. However,
            if the `dtype` in the config is `float32`, we will use `float16` instead.
127
        quantization: The method used to quantize the model weights. Currently,
128
            we support "awq", "gptq", and "fp8" (experimental).
129
130
131
132
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
133
134
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
135
136
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
137
138
139
140
141
142
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
143
144
145
146
147
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vLLM automatically infers the KV cache
            size from gpu_memory_utilization. Users who want to specify the
            KV cache memory size manually can set kv_cache_memory_bytes,
            which gives finer-grained control over memory usage than
            gpu_memory_utilization. Note that when kv_cache_memory_bytes is
            set (not None), gpu_memory_utilization is ignored.
151
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
157
158
159
160
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
161
162
163
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
164
        enable_return_routed_experts: Whether to return routed experts.
165
166
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
167
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
170
171
172
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
173
174
175
176
177
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
178
179
        pooler_config: Initialize non-default pooling config for the pooling model,
            e.g., `PoolerConfig(seq_pooling_type="MEAN", use_activation=False)`.
180
        compilation_config: Either an integer, a dictionary, or a
            CompilationConfig instance. If it is an integer, it is used as the
            mode of compilation optimization. If it is a dictionary, it can
            specify the full compilation configuration.
183
184
185
186
        attention_config: Configuration for attention mechanisms. Can be a
            dictionary or an AttentionConfig instance. If a dictionary, it will
            be converted to an AttentionConfig. Allows specifying the attention
            backend and other attention-related settings.
187
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
188

189
190
    Note:
        This class is intended to be used for offline inference. For online
191
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
192
    """
193
194
195
196

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: str | None = None,
        tokenizer_mode: TokenizerMode | str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        allowed_media_domains: list[str] | None = None,
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: QuantizationMethods | None = None,
        revision: str | None = None,
        tokenizer_revision: str | None = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        enable_return_routed_experts: bool = False,
        disable_custom_all_reduce: bool = False,
        hf_token: bool | str | None = None,
        hf_overrides: HfOverrides | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
        pooler_config: PoolerConfig | None = None,
        structured_outputs_config: dict[str, Any]
        | StructuredOutputsConfig
        | None = None,
        profiler_config: dict[str, Any] | ProfilerConfig | None = None,
        attention_config: dict[str, Any] | AttentionConfig | None = None,
        kv_cache_memory_bytes: int | None = None,
        compilation_config: int | dict[str, Any] | CompilationConfig | None = None,
        logits_processors: list[str | type[LogitsProcessor]] | None = None,
        **kwargs: Any,
    ) -> None:
        """LLM constructor.

        Normalizes the user-facing arguments (config dicts/ints, class-typed
        `worker_cls`), rejects unsupported single-process data parallelism,
        builds `EngineArgs`, and creates the underlying `LLMEngine`. See the
        class docstring for the meaning of each argument.

        Raises:
            ValueError: If `kv_transfer_config` is a dict that cannot be
                converted to a `KVTransferConfig`, or if
                `data_parallel_size > 1` is requested without the
                `external_launcher` executor backend on a non-TPU platform.
        """
        # Stats logging is disabled unless the caller explicitly sets it.
        kwargs.setdefault("disable_log_stats", True)

        # If worker_cls is a class (not a qualified string name), serialize
        # it with cloudpickle to avoid pickling issues.
        worker_cls = kwargs.get("worker_cls")
        if isinstance(worker_cls, type):
            kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        # Accept kv_transfer_config as a plain dict and convert it eagerly so
        # validation errors surface here with a clear message.
        raw_kv_config = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_config, dict):
            from vllm.config.kv_transfer import KVTransferConfig

            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(**raw_kv_config)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_config,
                    e,
                )
                raise ValueError(f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        def _make_config(value: Any, cls: type[_R]) -> _R:
            """Convert dict/None/instance to a config instance.

            Dict keys for which `is_init_field(cls, key)` is false are
            dropped rather than passed to the constructor.
            """
            if value is None:
                return cls()
            if isinstance(value, dict):
                return cls(**{k: v for k, v in value.items() if is_init_field(cls, k)})  # type: ignore[arg-type]
            return value

        # An int compilation_config selects the compilation mode directly.
        if isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                mode=CompilationMode(compilation_config)
            )
        else:
            compilation_config_instance = _make_config(
                compilation_config, CompilationConfig
            )

        structured_outputs_instance = _make_config(
            structured_outputs_config, StructuredOutputsConfig
        )
        profiler_config_instance = _make_config(profiler_config, ProfilerConfig)
        attention_config_instance = _make_config(attention_config, AttentionConfig)

        # Reject single-process data parallel usage (it may hang); only the
        # external launcher backend and TPU platforms are exempt.
        _dp_size = int(kwargs.get("data_parallel_size", 1))
        _distributed_executor_backend = kwargs.get("distributed_executor_backend")
        if (
            _dp_size > 1
            and _distributed_executor_backend != "external_launcher"
            and not current_platform.is_tpu()
        ):
            raise ValueError(
                f"LLM(data_parallel_size={_dp_size}) is not supported for single-"
                "process usage and may hang. Please use "
                "the explicit multi-process data-parallel example at "
                "'examples/offline_inference/data_parallel.py'."
            )

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            allowed_media_domains=allowed_media_domains,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            enable_return_routed_experts=enable_return_routed_experts,
            disable_custom_all_reduce=disable_custom_all_reduce,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooler_config=pooler_config,
            structured_outputs_config=structured_outputs_instance,
            profiler_config=profiler_config_instance,
            attention_config=attention_config_instance,
            compilation_config=compilation_config_instance,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS
        )
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: dict[str, Any] | None = None

        supported_tasks = self.llm_engine.get_supported_tasks()
        logger.info("Supported tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

        self.model_config = self.llm_engine.model_config
        self.input_processor = self.llm_engine.input_processor
        self.io_processor = self.llm_engine.io_processor

        # Cache for __repr__ to avoid repeated collective_rpc calls
        self._cached_repr: str | None = None

360
    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer held by the underlying engine."""
        engine = self.llm_engine
        return engine.get_tokenizer()
362

363
    def reset_mm_cache(self) -> None:
        """Clear the multi-modal caches.

        The input processor's cache is cleared first, followed by the
        engine's own multi-modal cache.
        """
        processor = self.input_processor
        processor.clear_mm_cache()
        engine = self.llm_engine
        engine.reset_mm_cache()

367
    def get_default_sampling_params(self) -> SamplingParams:
        """Return a fresh `SamplingParams` built from the model's defaults.

        The overrides come from `model_config.get_diff_sampling_param()`;
        they are fetched on the first call and cached on the instance. When
        there are no overrides, a default-constructed `SamplingParams` is
        returned.
        """
        cached = self.default_sampling_params
        if cached is None:
            cached = self.model_config.get_diff_sampling_param()
            self.default_sampling_params = cached

        if not cached:
            return SamplingParams()
        return SamplingParams.from_optional(**cached)

374
375
    def generate(
        self,
        prompts: PromptType | Sequence[PromptType],
        sampling_params: SamplingParams | Sequence[SamplingParams] | None = None,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        priority: list[int] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """Generate completions for the given prompts.

        Prompts are batched automatically with memory constraints in mind,
        so for the best performance submit all prompts together as a single
        list in one call.

        Args:
            prompts: One prompt, or a sequence of prompts for batch inference.
                See [PromptType][vllm.inputs.PromptType] for the accepted
                prompt formats.
            sampling_params: Sampling parameters for text generation. When
                None, the default sampling parameters are used. A single value
                is shared by every prompt; a list must have the same length as
                `prompts` and is paired with them one by one.
            use_tqdm: `True` shows a tqdm progress bar and `False` disables
                it. A callable (e.g. `functools.partial(tqdm, leave=False)`)
                is used to create the progress bar.
            lora_request: LoRA request(s) to use for generation, if any.
            priority: Request priorities, if any. Only applicable when the
                priority scheduling policy is enabled. When given, it must be
                a list of integers with one entry per prompt, matched by
                index.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            `RequestOutput` objects holding the generated completions, in the
            same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not `"generate"`.
        """
        if self.model_config.runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model."
            )

        params = (
            self.get_default_sampling_params()
            if sampling_params is None
            else sampling_params
        )
        resolved_loras = self._get_modality_specific_lora_reqs(prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=params,
            use_tqdm=use_tqdm,
            lora_request=resolved_loras,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
438

439
    def _get_modality_specific_lora_reqs(
        self,
        prompts: PromptType | Sequence[PromptType],
        lora_request: list[LoRARequest] | LoRARequest | None,
    ):
        """Resolve per-prompt LoRA requests using default multi-modal LoRAs.

        When the engine's LoRA config defines `default_mm_loras` and the model
        is multi-modal, each prompt is passed to
        `_resolve_single_prompt_mm_lora` together with its LoRA request.
        Otherwise `lora_request` is returned unchanged.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: None, a single LoRA request applied to every
                prompt, or a sequence with one entry per prompt.

        Returns:
            Either `lora_request` unchanged, or a list with one (possibly
            None) LoRA request per prompt.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from the number of prompts.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (
            lora_config is None
            or not self.model_config.is_multimodal_model
            or lora_config.default_mm_loras is None
        ):
            return lora_request

        if not isinstance(prompts, Sequence) or isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            # Reject mismatched lengths instead of letting zip() below
            # silently drop trailing prompts or LoRA requests.
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(prompts)

        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            )
            for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

475
476
477
    def _resolve_single_prompt_mm_lora(
        self,
        prompt: PromptType,
        lora_request: LoRARequest | None,
        default_mm_loras: dict[str, str] | None,
    ):
        """Choose the LoRA request for a single prompt.

        If the prompt is a dict whose `multi_modal_data` uses exactly one
        modality registered in `default_mm_loras`, that modality's default
        LoRA is used, unless an explicit `lora_request` was given. An explicit
        request always wins; a warning is logged if its ID differs from the
        modality's default LoRA ID. In every other case `lora_request` is
        returned unchanged (with a warning when several registered modalities
        match).

        Args:
            prompt: The prompt to inspect.
            lora_request: The explicitly requested LoRA, if any.
            default_mm_loras: Mapping from modality name to LoRA path.

        Returns:
            The LoRA request to use for this prompt, or None.
        """
        if not default_mm_loras or not isinstance(prompt, dict):
            return lora_request

        mm_data = prompt.get("multi_modal_data") or {}
        if not mm_data:
            return lora_request

        modalities = set(mm_data.keys())  # type: ignore
        matched = modalities.intersection(default_mm_loras)

        if not matched:
            return lora_request

        if len(matched) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped",
                matched,
            )
            return lora_request

        # Exactly one modality matched. The default mm lora ID is the
        # 1-based position of the modality name in sorted order.
        (modality_name,) = matched
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        if not lora_request:
            return LoRARequest(
                modality_name,
                modality_lora_id,
                modality_lora_path,
            )

        # An explicit request always takes precedence; only warn when it
        # disagrees with the modality's default LoRA.
        if lora_request.lora_int_id != modality_lora_id:
            logger.warning(
                "A modality with a registered lora and a lora_request "
                "with a different ID were provided; falling back to the "
                "lora_request as we only apply one LoRARequest per prompt"
            )
        return lora_request

528
529
    def collective_rpc(
        self,
        method: str | Callable[..., _R],
        timeout: float | None = None,
        args: tuple = (),
        kwargs: dict[str, Any] | None = None,
    ) -> list[_R]:
        """Run an RPC on every worker and gather the results.

        Args:
            method: Either the name of a worker method to invoke, or a
                callable that is serialized and executed on every worker.
                A callable receives an additional `self` argument (the worker
                object) besides the values passed in `args` and `kwargs`.
            timeout: Maximum number of seconds to wait for execution. A
                [`TimeoutError`][] is raised when it expires; `None` waits
                indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list with one result per worker.

        Note:
            Prefer using this API only for control messages, and set up a
            separate data-plane channel for passing data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
559
560

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """Run `func` on the model inside each worker and collect the results.

        !!! warning
            Returning large arrays or tensors from `func` adds data-transfer
            overhead, so avoid it. If you must return them, move them to the
            CPU first so they do not take up additional VRAM!
        """
        engine = self.llm_engine
        return engine.apply_model(func)
572

573
574
    def _get_beam_search_lora_requests(
        self,
        lora_request: list[LoRARequest] | LoRARequest | None,
        prompts: list[TokensPrompt | TextPrompt],
    ) -> list[LoRARequest | None]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: None, a single LoRA request shared by all prompts,
                or a sequence with exactly one entry per prompt.
            prompts: The beam-search prompts.

        Returns:
            A new list with one (possibly None) LoRA request per prompt.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from `len(prompts)`.
            TypeError: If `lora_request` is of any other type.
        """
        # A str is technically a Sequence but never a valid lora_request;
        # let it fall through to the TypeError below.
        if isinstance(lora_request, Sequence) and not isinstance(lora_request, str):
            if len(lora_request) != len(prompts):
                raise ValueError(
                    "Lora request list should be the same length as the prompts"
                )
            # Per-prompt requests are already aligned with `prompts`.
            return list(lora_request)

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

589
590
    def beam_search(
        self,
591
        prompts: list[TokensPrompt | TextPrompt],
592
        params: BeamSearchParams,
593
        lora_request: list[LoRARequest] | LoRARequest | None = None,
594
        use_tqdm: bool = False,
595
        concurrency_limit: int | None = None,
596
    ) -> list[BeamSearchOutput]:
597
598
599
600
601
602
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
603
            params: The beam search parameters.
604
            lora_request: LoRA request to use for generation, if any.
605
            use_tqdm: Whether to use tqdm to display the progress bar.
606
607
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
608
        """
609
610
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
611
612
613
614
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
615
616
        length_penalty = params.length_penalty

617
        lora_requests = self._get_beam_search_lora_requests(lora_request, prompts)
618

619
620
621
622
623
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
624

625
626
627
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
628
629
                "Disabling progress bar."
            )
630
631
632
633
634
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

635
636
        def create_tokens_prompt_from_beam(beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {"prompt_token_ids": beam.tokens}
637
638
639
640
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
641
                token_prompt_kwargs["mm_processor_kwargs"] = beam.mm_processor_kwargs
642
            return TokensPrompt(**token_prompt_kwargs)
643

644
645
646
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
647
        beam_search_params = SamplingParams(
648
649
650
651
            logprobs=2 * beam_width,
            max_tokens=1,
            temperature=temperature,
            skip_clone=True,  # Internal beam search, safe to skip clone
652
        )
653
        instances: list[BeamSearchInstance] = []
654

655
        for lora_req, prompt in zip(lora_requests, prompts):
656
657
658
659
660
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
661
                mm_kwargs["mm_processor_kwargs"] = prompt["mm_processor_kwargs"]
662

663
664
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
665
666
667
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
668

669
            instances.append(
670
671
672
673
674
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
675
676
                ),
            )
677

678
        for prompt_start in range(0, len(prompts), concurrency_limit):
679
            instances_batch = instances[prompt_start : prompt_start + concurrency_limit]
680
681
682

            token_iter = range(max_tokens)
            if use_tqdm:
683
684
685
                token_iter = tqdm(
                    token_iter, desc="Beam search", unit="token", unit_scale=False
                )
686
687
688
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
689
690
                    "reflect instance-level progress."
                )
691
692
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
693
694
                    sum((instance.beams for instance in instances_batch), [])
                )
695
696
                pos = [0] + list(
                    itertools.accumulate(
697
698
699
                        len(instance.beams) for instance in instances_batch
                    )
                )
700
                instance_start_and_end: list[tuple[int, int]] = list(
701
702
                    zip(pos[:-1], pos[1:])
                )
703
704
705
706
707
708

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
709
710
711
712
713
                    *[
                        (create_tokens_prompt_from_beam(beam), beam.lora_request)
                        for beam in all_beams
                    ]
                )
714
715
716

                # only runs for one step
                # we don't need to use tqdm here
717
718
719
720
721
722
                output = self.generate(
                    prompts_batch,
                    sampling_params=beam_search_params,
                    use_tqdm=False,
                    lora_request=lora_req_batch,
                )
723

724
725
726
                for (start, end), instance in zip(
                    instance_start_and_end, instances_batch
                ):
727
728
729
730
731
732
733
734
735
736
737
738
739
740
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
741
                                    logprobs=current_beam.logprobs + [logprobs],
742
                                    lora_request=current_beam.lora_request,
743
744
745
746
747
748
749
750
751
752
                                    cum_logprob=current_beam.cum_logprob
                                    + logprob_obj.logprob,
                                    multi_modal_data=current_beam.multi_modal_data,
                                    mm_processor_kwargs=current_beam.mm_processor_kwargs,
                                )

                                if (
                                    token_id == tokenizer.eos_token_id
                                    and not ignore_eos
                                ):
753
754
755
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
756
757
758
                    sorted_beams = sorted(
                        instance_new_beams, key=sort_beams_key, reverse=True
                    )
759
                    instance.beams = sorted_beams[:beam_width]
760
761
762
763

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
764
765
766
            sorted_completed = sorted(
                instance.completed, key=sort_beams_key, reverse=True
            )
767
768
769
770
771
772
773
774
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def _get_cmpl_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
        """Build the `TokenizeParams` used for completion-style prompts.

        Args:
            tokenization_kwargs: Optional overrides, applied on top of the
                model-derived defaults via `TokenizeParams.with_kwargs`.

        Returns:
            The parameters produced by `TokenizeParams.with_kwargs`.
        """
        cfg = self.model_config
        lower_case = (cfg.encoder_config or {}).get("do_lower_case", False)

        # For Whisper, special tokens should be provided by the user based
        # on the task and language of their request. Also needed to avoid
        # appending an EOS token to the prompt which disrupts generation.
        special_tokens = not cfg.is_encoder_decoder

        defaults = TokenizeParams(
            max_total_tokens=cfg.max_model_len,
            do_lower_case=lower_case,
            add_special_tokens=special_tokens,
        )
        return defaults.with_kwargs(tokenization_kwargs)

    def _normalize_prompts(
        self,
        prompts: PromptType | Sequence[PromptType],
    ) -> list[EnginePrompt | EngineEncDecPrompt]:
        """Normalize the `prompts` argument into a sequence of prompts.

        A bare string is first wrapped into a `TextPrompt`. Anything that is
        not a `Sequence` (such as a single prompt dict) is then wrapped into a
        one-element list, while a sequence of prompts is returned as-is.
        """
        # `str` is itself a `Sequence`, so it must be converted before the
        # `Sequence` check below or it would be iterated character by character.
        if isinstance(prompts, str):
            prompts = TextPrompt(prompt=prompts)

        return prompts if isinstance(prompts, Sequence) else [prompts]  # type: ignore[return-value]

    def _preprocess_cmpl_singleton(
        self,
        prompt: SingletonPrompt,
        tok_params: TokenizeParams,
        *,
        tokenize: bool,
    ) -> EnginePrompt:
        """Preprocess a single (non encoder/decoder) completion prompt.

        Non-dict prompts are first converted with the renderer's
        `render_completion`. When `tokenize` is True, the result is then
        passed through `renderer.tokenize_prompt` with `tok_params`;
        otherwise it is returned untokenized.
        """
        engine_renderer = self.llm_engine.renderer

        rendered = (
            prompt
            if isinstance(prompt, dict)
            else engine_renderer.render_completion(prompt)
        )

        if not tokenize:
            return rendered

        return engine_renderer.tokenize_prompt(rendered, tok_params)

    def _preprocess_cmpl_enc_dec(
        self,
        prompt: ExplicitEncoderDecoderPrompt,
        tok_params: TokenizeParams,
    ) -> EngineEncDecPrompt:
        """Preprocess an explicit encoder/decoder completion prompt.

        The encoder prompt and, when present, the decoder prompt are each run
        through `_preprocess_cmpl_singleton`; a `None` decoder prompt stays
        `None`.
        """
        # TODO: Move multi-modal processor into tokenization
        do_tokenize = not self.model_config.is_multimodal_model

        def _process(part: SingletonPrompt) -> EnginePrompt:
            return self._preprocess_cmpl_singleton(
                part, tok_params, tokenize=do_tokenize
            )

        encoder_part = prompt["encoder_prompt"]
        decoder_part = prompt["decoder_prompt"]

        return EngineEncDecPrompt(
            encoder_prompt=_process(encoder_part),
            decoder_prompt=None if decoder_part is None else _process(decoder_part),
        )

    def _preprocess_completion(
        self,
        prompts: PromptType | Sequence[PromptType],
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EnginePrompt | EngineEncDecPrompt]:
        """
        Convert prompt inputs from LLM APIs (other than [LLM.chat][]) into
        a format that can be passed to `_add_request`.

        Refer to [LLM.generate][] for a complete description of the arguments.

        Returns:
            A list with one entry per input prompt, in input order. Explicit
            encoder/decoder prompts become `EngineEncDecPrompt` objects; all
            other prompts become `EnginePrompt` objects, which are tokenized
            unless the model is multi-modal.
        """
        tok_params = self._get_cmpl_tok_params(tokenization_kwargs)

        # Some MM models have non-default `add_special_tokens`
        # TODO: Move multi-modal processor into tokenization
        tokenize = not self.model_config.is_multimodal_model

        engine_prompts = list[EnginePrompt | EngineEncDecPrompt]()
        for prompt in self._normalize_prompts(prompts):
            if is_explicit_encoder_decoder_prompt(prompt):
                engine_prompts.append(self._preprocess_cmpl_enc_dec(prompt, tok_params))
            else:
                engine_prompts.append(
                    self._preprocess_cmpl_singleton(
                        prompt,
                        tok_params,
                        tokenize=tokenize,
                    )
                )

        return engine_prompts

    def _normalize_conversations(
        self,
        conversations: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
    ) -> list[list[ChatCompletionMessageParam]]:
        """Return `conversations` as a list of conversations.

        If `is_list_of(conversations, list)` holds, the input is treated as a
        batch of conversations and returned as-is; otherwise it is treated as
        a single conversation and wrapped in a one-element list.
        """
        is_batch = is_list_of(conversations, list)
        return conversations if is_batch else [conversations]  # type: ignore[list-item,return-value]

    def _get_chat_tok_params(self, tokenization_kwargs: dict[str, Any] | None):
        """Build the `TokenizeParams` used for rendered chat prompts.

        Defaults `add_special_tokens` to False. `tokenization_kwargs` is
        then applied on top of the defaults via `TokenizeParams.with_kwargs`.
        """
        cfg = self.model_config
        lower_case = (cfg.encoder_config or {}).get("do_lower_case", False)

        defaults = TokenizeParams(
            max_total_tokens=cfg.max_model_len,
            do_lower_case=lower_case,
            add_special_tokens=False,
        )
        return defaults.with_kwargs(tokenization_kwargs)

    def _preprocess_chat(
        self,
        conversations: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        chat_template_kwargs: dict[str, Any] | None = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[EnginePrompt]:
        """
        Convert a list of conversations into prompts so that they can then
        be used as input for other LLM APIs.

        Refer to [LLM.chat][] for a complete description of the arguments.

        Returns:
            One engine prompt per conversation, in input order. Each is
            produced by rendering the conversation with the chat template
            and then tokenizing it. When `mm_processor_kwargs` is given, it
            is attached to every rendered prompt before tokenization.
        """
        renderer = self.llm_engine.renderer

        chat_params = ChatParams(
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=merge_kwargs(
                chat_template_kwargs,
                dict(
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                    # Tokenization during templating is only requested for
                    # Mistral tokenizers.
                    tokenize=isinstance(renderer.tokenizer, MistralTokenizer),
                ),
            ),
        )
        tok_params = self._get_chat_tok_params(tokenization_kwargs)

        engine_prompts = list[EnginePrompt]()
        for conversation in self._normalize_conversations(conversations):
            _, in_prompt = renderer.render_messages(conversation, chat_params)
            if mm_processor_kwargs is not None:
                in_prompt["mm_processor_kwargs"] = mm_processor_kwargs

            engine_prompts.append(renderer.tokenize_prompt(in_prompt, tok_params))

        return engine_prompts

    def chat(
        self,
        messages: list[ChatCompletionMessageParam]
        | list[list[ChatCompletionMessageParam]],
        sampling_params: SamplingParams | list[SamplingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: LoRARequest | None = None,
        chat_template: str | None = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: list[dict[str, Any]] | None = None,
        chat_template_kwargs: dict[str, Any] | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        mm_processor_kwargs: dict[str, Any] | None = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered with the chat template and tokenized,
        and the resulting prompts are passed to the
        [generate][vllm.LLM.generate] method to generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Tool definitions made available to the chat template; they
                are merged into the template kwargs together with
                `chat_template_kwargs`.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            mm_processor_kwargs: Overrides for `processor.__call__`.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        prompts = self._preprocess_chat(
            messages,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            chat_template_kwargs=chat_template_kwargs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            tokenization_kwargs=tokenization_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

    def encode(
        self,
        prompts: PromptType | Sequence[PromptType] | DataPrompt,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        pooling_task: PoolingTask | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        If `prompts` is a dict containing a `"data"` key, it is handled by the
        installed IOProcessor plugin. The plugin parses and pre-processes the
        request, validates or generates the pooling params, and post-processes
        the model outputs into a single `PoolingRequestOutput`.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Deprecated. Emits a `DeprecationWarning`
                and is merged into `tokenization_kwargs`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If `pooling_task` is not given or not supported by the
                model, if the model is not run as a pooling model, or if a
                `"data"` prompt is passed without an IOProcessor plugin.
        """
        if pooling_task is None:
            raise ValueError(
                "pooling_task required for `LLM.encode`\n"
                "Please use one of the more specific methods or set the "
                "pooling_task when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                'or `pooling_task="embed"`.\n'
                "  - For classification logits, use `LLM.classify(...)` "
                'or `pooling_task="classify"`.\n'
                "  - For similarity scores, use `LLM.score(...)`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                'or `pooling_task="token_classify"`\n'
                "  - For token classification, "
                'use `pooling_task="token_classify"`\n'
                '  - For multi-vector retrieval, use `pooling_task="token_embed"`'
            )

        model_config = self.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        if truncate_prompt_tokens is not None:
            warnings.warn(
                "The `truncate_prompt_tokens` parameter in `LLM.encode()` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )

            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details."
                )

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        if io_processor_prompt:
            assert self.io_processor is not None
            if is_list_of(pooling_params, PoolingParams):
                validated_pooling_params: list[PoolingParams] = []
                for param in as_iter(pooling_params):
                    validated_pooling_params.append(
                        self.io_processor.validate_or_generate_params(param)
                    )
                pooling_params = validated_pooling_params
            else:
                assert not isinstance(pooling_params, Sequence)
                pooling_params = self.io_processor.validate_or_generate_params(
                    pooling_params
                )

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        if pooling_task not in self.supported_tasks:
            raise ValueError(f"pooling_task must be one of {self.supported_tasks}.")

        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput
        )

        if io_processor_prompt:
            # get the post-processed model outputs
            assert self.io_processor is not None
            processed_outputs = self.io_processor.post_process(
                model_output=model_outputs
            )

            return [
                PoolingRequestOutput[Any](
                    request_id="",
                    outputs=processed_outputs,
                    num_cached_tokens=getattr(
                        processed_outputs, "num_cached_tokens", 0
                    ),
                    prompt_token_ids=[],
                    finished=True,
                )
            ]
        else:
            return model_outputs

    def embed(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: If set, merged into `tokenization_kwargs`
                under the `truncate_prompt_tokens` key.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"embed"` task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`."
            )

        if truncate_prompt_tokens is not None:
            # Merged here rather than forwarded to `encode`, so that the
            # deprecation warning in `encode` is not triggered.
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=truncate_prompt_tokens),
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: PromptType | Sequence[PromptType],
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the `"classify"` task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`."
            )

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
            tokenization_kwargs=tokenization_kwargs,
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def reward(
        self,
        prompts: PromptType | Sequence[PromptType],
        /,
        *,
        pooling_params: PoolingParams | Sequence[PoolingParams] | None = None,
        truncate_prompt_tokens: int | None = None,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        This is [LLM.encode][] with `pooling_task="token_classify"`.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Deprecated. Forwarded to [LLM.encode][],
                which emits a `DeprecationWarning` and merges it into
                `tokenization_kwargs`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_task="token_classify",
            tokenization_kwargs=tokenization_kwargs,
        )

    def _embedding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """Score `data_1` against `data_2` by cosine similarity of embeddings.

        All inputs from `data_1` followed by `data_2` are embedded in a single
        `encode(..., pooling_task="embed")` call, and the outputs are then
        split back into the two groups. If `data_1` produced exactly one
        embedding, it is repeated once per `data_2` embedding. The two lists
        are then passed to `_cosine_similarity` (presumably paired
        element-wise; confirm against that helper).

        Raises:
            NotImplementedError: If any input is not a plain string.
        """
        tokenizer = self.get_tokenizer()

        input_texts: list[str] = []
        for text in data_1 + data_2:
            if not isinstance(text, str):
                raise NotImplementedError(
                    "Embedding scores currently do not support multimodal input."
                )
            input_texts.append(text)

        encoded_output = self.encode(
            input_texts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        encoded_output_1 = encoded_output[0 : len(data_1)]
        encoded_output_2 = encoded_output[len(data_1) :]

        # A single query is broadcast against every document.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(
            tokenizer=tokenizer,
            embed_1=encoded_output_1,
            embed_2=encoded_output_2,
        )

        items = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
    def _late_interaction_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
    ) -> list[ScoringRequestOutput]:
        """
        Late interaction scoring (ColBERT MaxSim).

        Queries (``data_1``) and documents (``data_2``) are encoded with the
        ``"token_embed"`` pooling task in one ``encode`` call, producing
        per-token embeddings. Each query/document pair is then scored with
        ``compute_maxsim_score``. A single query is paired with every
        document; otherwise pairs are formed positionally.

        Raises:
            NotImplementedError: If any input is not a plain string.
        """
        from vllm.outputs import PoolingOutput

        tokenizer = self.get_tokenizer()

        # Collect the text inputs, rejecting anything that is not a string.
        queries: list[str] = []
        documents: list[str] = []
        for bucket, items in ((queries, data_1), (documents, data_2)):
            for item in items:
                if not isinstance(item, str):
                    raise NotImplementedError(
                        "Late interaction scores currently do not support multimodal input."
                    )
                bucket.append(item)

        outputs: list[PoolingRequestOutput] = self.encode(
            queries + documents,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="token_embed",
            tokenization_kwargs=tokenization_kwargs,
        )

        n_queries = len(queries)
        query_outputs: list[PoolingRequestOutput] = outputs[:n_queries]
        doc_outputs: list[PoolingRequestOutput] = outputs[n_queries:]
        if len(query_outputs) == 1:
            query_outputs = query_outputs * len(doc_outputs)

        # The reported prompt is query tokens + [pad token, if any] + doc tokens.
        pad_token_id = tokenizer.pad_token_id
        separator: list[int] = [] if pad_token_id is None else [pad_token_id]

        scores: list[PoolingRequestOutput] = []
        for q_out, d_out in zip(query_outputs, doc_outputs):
            # q_out.outputs.data: [query_len, dim]
            # d_out.outputs.data: [doc_len, dim]
            score = compute_maxsim_score(q_out.outputs.data, d_out.outputs.data)
            joined_tokens = q_out.prompt_token_ids + separator + d_out.prompt_token_ids
            cached = q_out.num_cached_tokens + d_out.num_cached_tokens
            scores.append(
                PoolingRequestOutput(
                    request_id=f"{q_out.request_id}_{d_out.request_id}",
                    outputs=PoolingOutput(data=score),
                    prompt_token_ids=joined_tokens,
                    num_cached_tokens=cached,
                    finished=True,
                )
            )

        validated = self.engine_class.validate_outputs(scores, PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in validated]

1453
1454
    def _cross_encoding_score(
        self,
        data_1: list[ScoreData],
        data_2: list[ScoreData],
        *,
        use_tqdm: bool | Callable[..., tqdm],
        pooling_params: PoolingParams | None,
        lora_request: list[LoRARequest] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any],
        score_template: str | None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs with a cross-encoder, one engine request per pair.

        Each ``(data_1[i], data_2[i])`` pair is turned into a single prompt by
        ``get_score_prompt``. A single ``data_1`` entry is replicated to match
        ``data_2``. When ``pooling_params`` is None, ``PoolingParams(task="score")``
        is used. The params are verified for the ``"score"`` task. If a prompt
        carries ``token_type_ids``, they are removed from the prompt. They are
        attached in compressed form to a clone of the pooling params.

        Raises:
            ValueError: If the tokenizer is a ``MistralTokenizer``.
        """
        model_config = self.model_config
        tokenizer = self.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError("Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")
        pooling_params.verify("score", model_config)

        prompts: list[PromptType] = []
        per_request_params: list[PoolingParams] = []
        for query, document in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=query,
                data_2=document,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
                score_template=score_template,
            )

            token_type_ids = engine_prompt.pop("token_type_ids", None)
            if token_type_ids:
                request_params = pooling_params.clone()
                request_params.extra_kwargs = {
                    "compressed_token_type_ids": compress_token_type_ids(
                        token_type_ids
                    )
                }
            else:
                request_params = pooling_params

            per_request_params.append(request_params)
            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=per_request_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        finished = self._run_engine(use_tqdm=use_tqdm)
        validated = self.engine_class.validate_outputs(finished, PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in validated]

1515
1516
    def score(
        self,
        data_1: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        data_2: SingletonPrompt
        | Sequence[SingletonPrompt]
        | ScoreMultiModalParam
        | list[ScoreMultiModalParam],
        /,
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        pooling_params: PoolingParams | None = None,
        lora_request: list[LoRARequest] | LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        chat_template: str | None = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input is replicated `N` times to
        pair with the `data_2` inputs.

        Pairs are scored differently depending on the model. Cross-encoder
        models run each pair as a single prompt. Late-interaction models
        (e.g. ColBERT) use MaxSim over per-token embeddings. All other models
        use cosine similarity of embeddings. Prompts are batched
        automatically within the memory constraint. For the best
        performance, pass all of your inputs to a single call.

        Text and multi-modal data (images, etc.) are both supported when used
        with appropriate multi-modal models. For multi-modal inputs, make sure
        the prompt structure matches the model's expected input format.

        Args:
            data_1: A single prompt, a list of prompts or
                `ScoreMultiModalParam`, containing either text or
                multi-modal data. When a list with more than one entry, it
                must have the same length as the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.
            tokenization_kwargs: Overrides for `tokenizer.encode`.
            chat_template: The chat template to use for the scoring. If None,
                we use the model's default chat template. Only supported for
                cross-encoder models.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, cannot score,
                is a cross-encoder with `num_labels != 1`, or if
                `chat_template` is given for a non-cross-encoder model.
        """
        config = self.model_config

        if config.runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model."
            )

        supported_tasks = self.supported_tasks

        # Late interaction models (e.g., ColBERT) score through token_embed,
        # so they do not need the embed/classify tasks.
        is_late_interaction = config.is_late_interaction
        if not is_late_interaction and not any(
            task in supported_tasks for task in ("embed", "classify")
        ):
            raise ValueError(
                "Score API is not supported by this model. "
                "Try converting the model using "
                "`--convert embed` or `--convert classify`."
            )

        is_cross_encoder = config.is_cross_encoder
        if is_cross_encoder and getattr(config.hf_config, "num_labels", 0) != 1:
            raise ValueError("Score API is only enabled for num_labels == 1.")

        if chat_template is not None and not is_cross_encoder:
            raise ValueError(
                "chat_template is only supported for cross-encoder models."
            )

        score_data_1, score_data_2 = validate_score_input(
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            is_multimodal_model=config.is_multimodal_model,
            architecture=config.architecture,
        )

        tokenizer_params = self._get_cmpl_tok_params(tokenization_kwargs)
        encode_kwargs = tokenizer_params.get_encode_kwargs()

        if is_cross_encoder:
            return self._cross_encoding_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
                score_template=chat_template,
            )

        if is_late_interaction:
            return self._late_interaction_score(
                score_data_1,
                score_data_2,
                use_tqdm=use_tqdm,
                pooling_params=pooling_params,
                lora_request=lora_request,
                tokenization_kwargs=encode_kwargs,
            )

        return self._embedding_score(
            score_data_1,
            score_data_2,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
        )
1645

1646
1647
1648
1649
1650
1651
    def start_profile(self) -> None:
        """Ask the underlying engine to start profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1652
1653
1654
1655
1656
1657
    def reset_prefix_cache(
        self, reset_running_requests: bool = False, reset_connector: bool = False
    ) -> bool:
        """Reset the engine's prefix cache.

        Args:
            reset_running_requests: Forwarded to the engine unchanged.
            reset_connector: Forwarded to the engine unchanged.

        Returns:
            The value returned by the engine's `reset_prefix_cache`.
        """
        engine = self.llm_engine
        succeeded = engine.reset_prefix_cache(reset_running_requests, reset_connector)
        return succeeded
1658

1659
1660
1661
1662
1663
1664
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep; it should not process any requests while
        asleep. The caller must make sure no requests are being processed
        during the sleep period, i.e. until `wake_up` is called.

        Args:
            level: The sleep level.
                Level 1 offloads the model weights and discards the kv cache.
                The kv cache content is lost, but the weights are backed up in
                CPU memory, so make sure there is enough CPU memory to hold
                them. Use level 1 when waking up to run the same model again.
                Level 2 discards both the model weights and the kv cache, so
                neither is preserved. Use level 2 when waking up to run a
                different or updated model, where the previous weights are not
                needed. It reduces CPU memory pressure.
        """
        # Clear the prefix cache (default arguments) before putting the
        # engine to sleep.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1681
    def wake_up(self, tags: list[str] | None = None):
1682
        """
1683
1684
        Wake up the engine from sleep mode. See the [sleep][vllm.LLM.sleep]
        method for more details.
1685

1686
        Args:
1687
1688
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
1689
                `("weights", "kv_cache")`. If None, all memory is reallocated.
1690
                wake_up should be called with all tags (or None) before the
1691
1692
1693
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1694

1695
1696
1697
1698
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus, as reported by the engine.

        Note:
            This method is only available with the V1 LLM engine.
        """
        return self.llm_engine.get_metrics()

1707
1708
    def _validate_and_add_requests(
        self,
        prompts: PromptType | Sequence[PromptType],
        params: SamplingParams
        | Sequence[SamplingParams]
        | PoolingParams
        | Sequence[PoolingParams],
        *,
        use_tqdm: bool | Callable[..., tqdm] = True,
        lora_request: Sequence[LoRARequest | None] | LoRARequest | None,
        tokenization_kwargs: dict[str, Any] | None = None,
        priority: list[int] | None = None,
    ) -> None:
        """Validate per-request arguments and add the prompts to the engine.

        `params` and `lora_request` may each be given once, in which case the
        value is shared by every prompt. They may also be given as a sequence
        with exactly one entry per prompt. `priority`, when given, must have
        one entry per prompt; it defaults to 0 for every request.

        `SamplingParams` are switched to `RequestOutputKind.FINAL_ONLY` since
        only the final output is collected.

        Args:
            prompts: One prompt or a sequence of prompts.
            params: Sampling/pooling params, shared or one per prompt.
            use_tqdm: If truthy, wrap request submission in a progress bar.
                A callable is used as the progress-bar factory.
            lora_request: LoRA request, shared or one per prompt.
            tokenization_kwargs: Overrides forwarded to tokenization.
            priority: Optional per-request priorities.

        Raises:
            ValueError: If a per-request sequence does not have the same
                length as the normalized prompts.
        """
        in_prompts = self._normalize_prompts(prompts)
        num_requests = len(in_prompts)

        if isinstance(params, Sequence):
            if len(params) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and params ({len(params)}) must be the same."
                )

            engine_params = params
        else:
            engine_params = [params] * num_requests

        if isinstance(lora_request, Sequence):
            if len(lora_request) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and lora_request ({len(lora_request)}) must be the same."
                )

            engine_lora_requests: Sequence[LoRARequest | None] = lora_request
        else:
            engine_lora_requests = [lora_request] * num_requests

        if priority is not None:
            if len(priority) != num_requests:
                raise ValueError(
                    f"The lengths of prompts ({num_requests}) "
                    f"and priority ({len(priority)}) must be the same."
                )
        else:
            priority = [0] * num_requests

        if any(param.truncate_prompt_tokens is not None for param in engine_params):
            # TODO: Remove this after deprecating `param.truncate_prompt_tokens`
            # Then, move the code from the `else` block to the top and let
            # `self._preprocess_completion` handle prompt normalization
            engine_prompts = [
                engine_prompt
                for in_prompt, param in zip(in_prompts, engine_params)
                for engine_prompt in self._preprocess_completion(
                    [in_prompt],
                    tokenization_kwargs=merge_kwargs(
                        tokenization_kwargs,
                        dict(truncate_prompt_tokens=param.truncate_prompt_tokens),
                    ),
                )
            ]
        else:
            engine_prompts = self._preprocess_completion(
                in_prompts,
                tokenization_kwargs=tokenization_kwargs,
            )

        for sp in engine_params:
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = engine_prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        added_request_ids: list[str] = []

        try:
            for i, prompt in enumerate(it):
                request_id = self._add_request(
                    prompt,
                    engine_params[i],
                    lora_request=engine_lora_requests[i],
                    tokenization_kwargs=tokenization_kwargs,
                    priority=priority[i],
                )
                added_request_ids.append(request_id)
        except Exception:
            # Roll back the requests added by this call so a partial batch is
            # not left running in the engine, then re-raise unchanged.
            if added_request_ids:
                self.llm_engine.abort_request(added_request_ids, internal=True)
            raise
1802

1803
    def _add_request(
        self,
        prompt: PromptType,
        params: SamplingParams | PoolingParams,
        lora_request: LoRARequest | None = None,
        tokenization_kwargs: dict[str, Any] | None = None,
        priority: int = 0,
    ) -> str:
        """Process one prompt and submit it to the engine.

        A fresh request ID is drawn from `self.request_counter`. A deprecated
        `params.truncate_prompt_tokens` emits a `DeprecationWarning` and is
        merged into `tokenization_kwargs`. The prompt is then processed by the
        input processor and added to the engine.

        Returns:
            The `request_id` of the processed engine request.
        """
        prompt_text, _, _ = get_prompt_components(prompt)
        request_id = str(next(self.request_counter))

        legacy_truncation = params.truncate_prompt_tokens
        if legacy_truncation is not None:
            params_cls_name = type(params).__name__
            warnings.warn(
                f"The `truncate_prompt_tokens` parameter in `{params_cls_name}` "
                "is deprecated and will be removed in v0.16. "
                "Please pass it via `tokenization_kwargs` instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            tokenization_kwargs = merge_kwargs(
                tokenization_kwargs,
                dict(truncate_prompt_tokens=legacy_truncation),
            )

        tokenizer_params = self._get_cmpl_tok_params(tokenization_kwargs)
        encode_kwargs = tokenizer_params.get_encode_kwargs()

        engine_request = self.input_processor.process_inputs(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
            priority=priority,
        )

        self.llm_engine.add_request(
            request_id,
            engine_request,
            params,
            lora_request=lora_request,
            tokenization_kwargs=encode_kwargs,
            priority=priority,
            prompt_text=prompt_text,
        )
        return engine_request.request_id
1851

1852
    def _run_engine(
        self, *, use_tqdm: bool | Callable[..., tqdm] = True
    ) -> list[RequestOutput | PoolingRequestOutput]:
        """Step the engine until it has no unfinished requests.

        Args:
            use_tqdm: If truthy, show a progress bar over the requests that
                are unfinished when this method starts. A callable is used as
                the progress-bar factory. For `RequestOutput`s, the bar's
                postfix shows estimated input/output token throughput.

        Returns:
            All finished outputs, sorted by their integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[RequestOutput | PoolingRequestOutput] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs
                            )
                            # A disabled bar (e.g. `partial(tqdm, disable=True)`)
                            # reports 0 elapsed time; skip the speed estimate
                            # instead of dividing by zero.
                            elapsed = pbar.format_dict["elapsed"]
                            if elapsed:
                                in_spd = total_in_toks / elapsed
                                out_spd = total_out_toks / elapsed
                                pbar.postfix = (
                                    f"est. speed input: {in_spd:.2f} toks/s, "
                                    f"output: {out_spd:.2f} toks/s"
                                )
                            pbar.update(n)
                        else:
                            pbar.update(1)
                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914

    def __repr__(self) -> str:
        """Return a transformers-style hierarchical view of the model.

        The text comes from the `get_model_inspection` collective RPC and is
        cached in `self._cached_repr`, so the RPC only runs on the first
        call. If no worker returns a result, a short `LLM(model=...)` string
        is cached instead.
        """
        if self._cached_repr is not None:
            return self._cached_repr

        worker_views = self.llm_engine.collective_rpc("get_model_inspection")
        # Distributed runs return one result per worker; they are expected to
        # be identical, so the first one is used.
        if not worker_views:
            self._cached_repr = f"LLM(model={self.model_config.model!r})"
        else:
            self._cached_repr = worker_views[0]
        return self._cached_repr