llm.py 69.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
from collections.abc import Sequence
6
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
7

8
import cloudpickle
9
import torch.nn as nn
10
from pydantic import ValidationError
11
from tqdm.auto import tqdm
12
from typing_extensions import TypeVar
13

14
import vllm.envs as envs
15
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
16
17
                              BeamSearchSequence,
                              create_sort_beams_key_function)
18
19
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
20
21
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
                                   PoolerConfig, RunnerOption)
Joe Runde's avatar
Joe Runde committed
22
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
23
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
24
                                         ChatTemplateContentFormatOption,
25
26
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
27
28
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
29
30
# yapf conflicts with isort for this block
# yapf: disable
31
32
33
34
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
35
                                          compress_token_type_ids,
36
                                          get_score_prompt)
37
# yapf: enable
38
39
from vllm.entrypoints.utils import (_validate_truncation_size,
                                    log_non_default_args)
40
41
from vllm.inputs import (DataPrompt, PromptType, SingletonPrompt, TextPrompt,
                         TokensPrompt)
42
from vllm.logger import init_logger
43
from vllm.lora.request import LoRARequest
44
from vllm.model_executor.layers.quantization import QuantizationMethods
45
46
47
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
48
from vllm.plugins.io_processors import get_io_processor
49
from vllm.pooling_params import PoolingParams
50
51
from vllm.sampling_params import (BeamSearchParams, RequestOutputKind,
                                  SamplingParams)
52
from vllm.tasks import PoolingTask
53
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
54
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
55
from vllm.usage.usage_lib import UsageContext
56
from vllm.utils import Counter, Device, as_iter, is_list_of
57
from vllm.v1.sample.logits_processor import LogitsProcessor
58

59
60
61
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

62
63
logger = init_logger(__name__)

# Generic result type for `collective_rpc` / `apply_model`; falls back to
# `Any` when a caller does not bind it.
_R = TypeVar("_R", default=Any)

66
67

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
68
69
70
71
72
73
74
75
76
77
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
78
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
79
80
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
81
82
83
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
84
85
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
86
87
88
89
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
90
91
92
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
93
94
95
96
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
97
        quantization: The method used to quantize the model weights. Currently,
98
            we support "awq", "gptq", and "fp8" (experimental).
99
100
101
102
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
103
104
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
105
106
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
107
108
109
110
111
112
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
113
114
115
116
117
118
119
120
        kv_cache_memory_bytes: Size of KV Cache per GPU in bytes. By default,
            this is set to None and vLLM can automatically infer the KV cache
            size based on gpu_memory_utilization. However, users may want to
            manually specify the KV cache memory size. kv_cache_memory_bytes
            allows more fine-grained control of how much memory gets used
            compared with using gpu_memory_utilization. Note that when
            kv_cache_memory_bytes is not None, gpu_memory_utilization is
            ignored.
121
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
122
123
124
125
126
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, values that are too small may cause out-of-memory (OOM)
            errors. Note that `best_of` is only supported in V0.
127
128
129
130
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
131
132
133
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
134
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
135
            When a sequence has context length larger than this, we fall back
136
137
138
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
139
140
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
141
142
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
143
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
146
147
148
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
149
150
151
152
153
154
155
156
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        override_pooler_config: Initialize non-default pooling config or
            override default pooling config for the pooling model.
            e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
157
158
159
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
160
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
161

162
163
    Note:
        This class is intended to be used for offline inference. For online
164
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
165
    """
166
167
168
169

    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        kv_cache_memory_bytes: Optional[int] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        logits_processors: Optional[list[Union[str,
                                               type[LogitsProcessor]]]] = None,
        **kwargs: Any,
    ) -> None:
        """Build the engine that backs this offline `LLM` instance.

        Before the arguments are handed to `EngineArgs`, a few of them are
        normalized:

        - `disable_log_stats` defaults to `True` unless passed in `kwargs`.
        - A `worker_cls` given as a class object is serialized with
          cloudpickle.
        - A `kv_transfer_config` given as a dict is validated into a
          `KVTransferConfig`.
        - `hf_overrides=None` becomes an empty dict.
        - `compilation_config` is coerced into a `CompilationConfig`.

        Raises:
            ValueError: If a dict `kv_transfer_config` fails validation.
        """
        # Offline use: stats logging stays off unless explicitly requested.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class object (as opposed to a qualified-name string) is
        # serialized with cloudpickle to avoid pickling issues.
        custom_worker = kwargs.get("worker_cls")
        if isinstance(custom_worker, type):
            kwargs["worker_cls"] = cloudpickle.dumps(custom_worker)

        raw_kv_transfer = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer, dict):
            from vllm.config.kv_transfer import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_kv_transfer)
            except ValidationError as err:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer, err)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {err}") from err

        if hf_overrides is None:
            hf_overrides = {}

        # Accept an int (compilation level), a dict (keys that are not init
        # fields of CompilationConfig are silently dropped), a ready-made
        # instance, or None (default configuration).
        if compilation_config is None:
            resolved_compilation = CompilationConfig()
        elif isinstance(compilation_config, int):
            resolved_compilation = CompilationConfig(level=compilation_config)
        elif isinstance(compilation_config, dict):
            resolved_compilation = CompilationConfig(
                **{
                    key: value
                    for key, value in compilation_config.items()
                    if is_init_field(CompilationConfig, key)
                })
        else:
            resolved_compilation = compilation_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            kv_cache_memory_bytes=kv_cache_memory_bytes,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=resolved_compilation,
            logits_processors=logits_processors,
            **kwargs,
        )

        log_non_default_args(engine_args)

        # Create the engine (autoselects V0 vs V1).
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

        if envs.VLLM_USE_V1:
            tasks = self.llm_engine.get_supported_tasks()  # type: ignore
        else:
            tasks = self.llm_engine.model_config.supported_tasks
        logger.info("Supported_tasks: %s", tasks)
        self.supported_tasks = tasks

        # Load the Input/Output processor plugin, if any.
        plugin_name = self.llm_engine.model_config.io_processor_plugin
        self.io_processor = get_io_processor(self.llm_engine.vllm_config,
                                             plugin_name)

304
305
306
307
308
309
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer to use for `lora_request`.

        The lookup goes through the engine's tokenizer group, which is asked
        for the tokenizer associated with `lora_request` (`None` by default).
        """
        group = self.llm_engine.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)
310
311

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the engine's tokenizer group.

        A tokenizer whose class name starts with `"Cached"` is stored as-is;
        any other tokenizer is wrapped with `get_cached_tokenizer` first.
        """
        group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are created dynamically, so the class-name
        # prefix is the only available check. A user-defined tokenizer whose
        # class name happens to start with "Cached" is misjudged and stored
        # without wrapping.
        already_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer
                           if already_cached else get_cached_tokenizer(tokenizer))
321

322
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the model's default sampling parameters.

        The overrides from `model_config.get_diff_sampling_param()` are
        fetched on first use and cached in `self.default_sampling_params`.
        If the cached value is empty (falsy), a plain `SamplingParams()` is
        returned.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

330
331
332
333
334
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, the defaults from `get_default_sampling_params()` are
                used.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. Either a
                single request or a list of requests paired one by one with
                the prompts. For multimodal models configured with default
                modality-specific LoRAs, prompts without an explicit request
                may be assigned the LoRA registered for their modality.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a generative model, i.e. its
                runner type is not `"generate"`.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model.")

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Add any modality specific loras to the corresponding prompts.
        lora_request = self._get_modality_specific_lora_reqs(
            prompts, lora_request)

        self._validate_and_add_requests(
            prompts=prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
398

399
    def _get_modality_specific_lora_reqs(
            self, prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        """Attach default modality-specific LoRAs to multimodal prompts.

        Args:
            prompts: A single prompt or a sequence of prompts.
            lora_request: The caller-provided LoRA request(s): `None`, one
                request applied to every prompt, or a sequence paired one by
                one with `prompts`.

        Returns:
            `lora_request` unchanged when there is no LoRA config, the model
            is not multimodal, or no `default_mm_loras` are registered.
            Otherwise a list with one optional `LoRARequest` per prompt, as
            resolved by `_resolve_single_prompt_mm_lora` (explicitly provided
            requests take precedence).

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from the number of prompts.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or lora_config.default_mm_loras is None):
            return lora_request

        # A bare string is itself a Sequence, so it must be wrapped
        # explicitly; otherwise it would be iterated character by character
        # and yield one LoRA entry per character.
        if isinstance(prompts, str) or not isinstance(prompts, Sequence):
            prompts = [prompts]

        if isinstance(lora_request, Sequence):
            # zip() below would silently truncate a mismatched list.
            if len(lora_request) != len(prompts):
                raise ValueError("Lora request list should be the same "
                                 "length as the prompts")
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(prompts)

        default_mm_loras = lora_config.default_mm_loras
        return [
            self._resolve_single_prompt_mm_lora(
                prompt,
                opt_lora_req,
                default_mm_loras,
            ) for prompt, opt_lora_req in zip(prompts, optional_loras)
        ]

428
    def _resolve_single_prompt_mm_lora(self, prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        """Choose the LoRA request to use for a single prompt.

        A truthy `lora_request` is always returned as-is. When no request is
        given and the prompt's `multi_modal_data` uses exactly one modality
        that has a registered default LoRA, a `LoRARequest` for that modality
        is built. In every other case `lora_request` is returned unchanged.
        """
        if not default_mm_loras:
            return lora_request
        if not isinstance(prompt, dict) or "multi_modal_data" not in prompt:
            return lora_request

        prompt = cast(Union[TextPrompt, TokensPrompt], prompt)
        mm_data = prompt["multi_modal_data"]
        matched = set(mm_data.keys()).intersection(default_mm_loras)

        if not matched:
            return lora_request
        if len(matched) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", matched)
            return lora_request

        modality_name = matched.pop()
        # A default mm LoRA's ID is the 1-based position of its modality name
        # in the alphabetically sorted registry.
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        if not lora_request:
            return LoRARequest(
                modality_name,
                modality_lora_id,
                default_mm_loras[modality_name],
            )

        # The explicit request wins; only warn when its ID disagrees with the
        # modality's default LoRA ID.
        if lora_request.lora_int_id != modality_lora_id:
            logger.warning(
                "A modality with a registered lora and a lora_request "
                "with a different ID were provided; falling back to the "
                "lora_request as we only apply one LoRARequest per prompt")
        return lora_request

474
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Invoke `method` on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to every worker for execution.

                A callable receives the worker object as an extra leading
                `self` argument, followed by `args` and `kwargs`.
            timeout: Upper bound, in seconds, on the wait for execution.
                Raises a [`TimeoutError`][] when exceeded; `None` waits
                without limit.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, collected into a list.

        Note:
            Prefer this API for control messages only; set up a separate
            data-plane communication channel for bulk data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
503
504

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` directly on the model inside each worker.

        Delegates to the engine's model executor and returns the result
        produced in each worker.
        """
        return self.llm_engine.model_executor.apply_model(func)
511

512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: `None`, a single request used for every prompt, or
                a sequence with exactly one entry per prompt.
            prompts: The beam search prompts.

        Returns:
            A list holding one optional `LoRARequest` per prompt.

        Raises:
            ValueError: If `lora_request` is a sequence whose length does not
                match `len(prompts)`.
            TypeError: If `lora_request` is of any other type.
        """
        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(prompts):
                raise ValueError("Lora request list should be the same "
                                 "length as the prompts")
            # A correctly sized per-prompt list is valid input; previously it
            # fell through to the TypeError below. Copy into a fresh list so
            # the return type holds for tuples and the caller's sequence is
            # not aliased.
            return list(lora_request)

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

528
529
    def beam_search(
        self,
530
        prompts: list[Union[TokensPrompt, TextPrompt]],
531
        params: BeamSearchParams,
532
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
533
        use_tqdm: bool = False,
534
        concurrency_limit: Optional[int] = None,
535
    ) -> list[BeamSearchOutput]:
536
537
538
539
540
541
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
542
            params: The beam search parameters.
543
            lora_request: LoRA request to use for generation, if any.
544
            use_tqdm: Whether to use tqdm to display the progress bar.
545
546
            concurrency_limit: The maximum number of concurrent requests.
                If None, the number of concurrent requests is unlimited.
547
        """
548
549
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
550
551
552
553
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
554
555
        length_penalty = params.length_penalty

556
557
558
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

559
560
561
562
563
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
564

565
566
567
568
569
570
571
572
573
        if use_tqdm and concurrency_limit is not None:
            logger.warning(
                "Progress bar is not supported when using concurrency_limit. "
                "Disabling progress bar.")
            use_tqdm = False

        if concurrency_limit is None:
            concurrency_limit = len(prompts)

574
575
576
577
578
579
580
581
582
583
584
585
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
586

587
588
589
590
591
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
592
                                            temperature=temperature)
593
        instances: list[BeamSearchInstance] = []
594

595
        for lora_req, prompt in zip(lora_requests, prompts):
596
597
598
599
600
601
602
603
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

604
605
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
606
607
608
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
609

610
            instances.append(
611
612
613
614
615
616
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
617

618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
        for prompt_start in range(0, len(prompts), concurrency_limit):
            instances_batch = instances[prompt_start:prompt_start +
                                        concurrency_limit]

            token_iter = range(max_tokens)
            if use_tqdm:
                token_iter = tqdm(token_iter,
                                  desc="Beam search",
                                  unit="token",
                                  unit_scale=False)
                logger.warning(
                    "The progress bar shows the upper bound on token steps and "
                    "may finish early due to stopping conditions. It does not "
                    "reflect instance-level progress.")
            for _ in token_iter:
                all_beams: list[BeamSearchSequence] = list(
                    sum((instance.beams for instance in instances_batch), []))
                pos = [0] + list(
                    itertools.accumulate(
                        len(instance.beams) for instance in instances_batch))
                instance_start_and_end: list[tuple[int, int]] = list(
                    zip(pos[:-1], pos[1:]))

                if len(all_beams) == 0:
                    break

                # create corresponding batch entries for prompt & optional lora
                prompts_batch, lora_req_batch = zip(
                    *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                      for beam in all_beams])

                # only runs for one step
                # we don't need to use tqdm here
                output = self.generate(prompts_batch,
                                       sampling_params=beam_search_params,
                                       use_tqdm=False,
                                       lora_request=lora_req_batch)

                for (start, end), instance in zip(instance_start_and_end,
                                                  instances_batch):
                    instance_new_beams = []
                    for i in range(start, end):
                        current_beam = all_beams[i]
                        result = output[i]

                        if result.outputs[0].logprobs is not None:
                            # if `result.outputs[0].logprobs` is None, it means
                            # the sequence is completed because of the
                            # max-model-len or abortion. we don't need to add
                            # it to the new beams.
                            logprobs = result.outputs[0].logprobs[0]
                            for token_id, logprob_obj in logprobs.items():
                                new_beam = BeamSearchSequence(
                                    tokens=current_beam.tokens + [token_id],
                                    logprobs=current_beam.logprobs +
                                    [logprobs],
                                    lora_request=current_beam.lora_request,
                                    cum_logprob=current_beam.cum_logprob +
                                    logprob_obj.logprob,
                                    multi_modal_data=current_beam.
                                    multi_modal_data,
                                    mm_processor_kwargs=current_beam.
                                    mm_processor_kwargs)

                                if token_id == tokenizer.eos_token_id and \
                                    not ignore_eos:
                                    instance.completed.append(new_beam)
                                else:
                                    instance_new_beams.append(new_beam)
                    sorted_beams = sorted(instance_new_beams,
                                          key=sort_beams_key,
                                          reverse=True)
                    instance.beams = sorted_beams[:beam_width]
691
692
693
694
695

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
696
                                      key=sort_beams_key,
697
698
699
700
701
702
703
704
705
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

706
    def preprocess_chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[TokensPrompt]:
        """
        Generate prompts for one or more chat conversations. The
        pre-processed prompts can then be used as input for the other LLM
        methods.

        Refer to `chat` for a complete description of the arguments.

        Returns:
            A list of `TokensPrompt` objects, one per conversation and in the
            same order, each containing the tokenized prompt after chat
            template interpolation and, when present, the parsed multi-modal
            data, multi-modal UUIDs and `mm_processor_kwargs`.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Accept either a single conversation (a list of messages) or a
        # batch of conversations (a list of lists of messages).
        if is_list_of(messages, list):
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Caller-supplied `chat_template_kwargs` are applied last, so they
        # take precedence over the explicit keyword arguments above.
        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[TokensPrompt] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data, mm_uuids = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # Mistral tokenizers render the raw messages; the result is
                # used directly as the prompt token ids.
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            # Only attach optional fields that are actually present.
            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_uuids is not None:
                prompt["multi_modal_uuids"] = mm_uuids

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return prompts

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a tokenized prompt using the
        tokenizer's chat template (see `preprocess_chat`) and then passed to
        the [generate][vllm.LLM.generate] method to generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "auto" (the default) resolves the format automatically.
                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional list of tool definitions (as dictionaries) that
                is forwarded to the chat template.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. These take precedence over the explicit arguments
                above when the same key is given.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        # Render (and tokenize) every conversation up front, then run a
        # single batched generation over the resulting prompts.
        prompts = self.preprocess_chat(
            messages=messages,
            lora_request=lora_request,
            chat_template=chat_template,
            chat_template_content_format=chat_template_content_format,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            chat_template_kwargs=chat_template_kwargs,
            mm_processor_kwargs=mm_processor_kwargs,
        )

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType], DataPrompt],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
                A `DataPrompt` (a dict with a `"data"` key) is routed through
                the installed IOProcessor plugin for pre- and post-processing.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            truncate_prompt_tokens: Kept for backwards compatibility. If not
                None, it is set as `truncate_prompt_tokens` on every pooling
                params object.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: The pooling task to run; must be one of
                `self.supported_tasks`. If explicitly passed as `None`,
                "embed" is chosen when supported (otherwise "encode") and a
                warning is logged.
            tokenization_kwargs: Intended to override tokenization_kwargs set
                in pooling_params. NOTE(review): this method does not
                currently forward it anywhere; confirm whether it should
                reach `_validate_and_add_requests`.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
            For a `DataPrompt`, a single-element list wrapping the
            IOProcessor's post-processed output is returned instead.

        Raises:
            ValueError: If the model is not run as a pooling model, if
                `pooling_task` is not supported, or if a `DataPrompt` is
                given but no IOProcessor plugin is installed.
        """
        # The default is "encode", so this branch only runs when a caller
        # explicitly passes `pooling_task=None` (legacy behaviour).
        if pooling_task is None:
            if "embed" in self.supported_tasks:
                pooling_task = "embed"
            else:
                pooling_task = "encode"

            logger.warning_once(
                "`LLM.encode` is currently using `pooling_task = %s`.\n"
                "Please use one of the more specific methods or set the "
                "task directly when using `LLM.encode`:\n"
                "  - For embeddings, use `LLM.embed(...)` "
                "or `pooling_task=\"embed\"`.\n"
                "  - For classification logits, use `LLM.classify(...)` "
                "or `pooling_task=\"classify\"`.\n"
                "  - For rewards, use `LLM.reward(...)` "
                "or `pooling_task=\"reward\"`\n"
                "  - For similarity scores, use `LLM.score(...)`.",
                pooling_task)

        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        if pooling_task not in self.supported_tasks:
            raise ValueError(
                f"pooling_task must be one of {self.supported_tasks}.")

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        # NOTE: caller-supplied PoolingParams objects are modified in place
        # here (at least `truncate_prompt_tokens`).
        for param in as_iter(pooling_params):
            param.verify(pooling_task, model_config)
            # for backwards compatibility
            if truncate_prompt_tokens is not None:
                param.truncate_prompt_tokens = truncate_prompt_tokens

        io_processor_prompt = False
        if isinstance(prompts, dict) and "data" in prompts:
            io_processor_prompt = True
            if self.io_processor is None:
                raise ValueError(
                    "No IOProcessor plugin installed. Please refer "
                    "to the documentation and to the "
                    "'prithvi_geospatial_mae_io_processor' "
                    "offline inference example for more details.")

            # Validate the request data is valid for the loaded plugin
            validated_prompt = self.io_processor.parse_request(prompts)

            # obtain the actual model prompts from the pre-processor
            prompts = self.io_processor.pre_process(prompt=validated_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)

        model_outputs = self.engine_class.validate_outputs(
            outputs, PoolingRequestOutput)

        if not io_processor_prompt:
            return model_outputs

        # get the post-processed model outputs
        assert self.io_processor is not None
        processed_outputs = self.io_processor.post_process(
            model_output=model_outputs)

        return [
            PoolingRequestOutput[Any](request_id="",
                                      outputs=processed_outputs,
                                      prompt_token_ids=[],
                                      finished=True)
        ]

    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`, which sets it on
                every pooling params object when not None.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`.")

        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
        )

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`, which sets it on
                every pooling params object when not None. Defaults to None
                (no override), matching `embed` and `reward`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`.")

        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def reward(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[PoolingRequestOutput]:
        """
        Generate rewards for each prompt.

        This is a thin wrapper around `encode` with `pooling_task="encode"`.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded to `encode`, which sets it on
                every pooling params object when not None.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.
        """
        return self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            truncate_prompt_tokens=truncate_prompt_tokens,
            pooling_task="encode",
        )

    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by the cosine similarity of their embeddings.

        All inputs from `text_1` and `text_2` are embedded in a single
        batched `encode` call (task "embed"). When `text_1` holds exactly
        one entry, its embedding is paired with every entry of `text_2`
        (the `1 -> N` case); otherwise entries are paired positionally.

        Returns:
            One `ScoringRequestOutput` per scored pair.
        """
        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_params=pooling_params,
            pooling_task="embed",
        )

        # Outputs come back in input order, so split at the boundary
        # between the two groups.
        split = len(text_1)
        encoded_output_1: list[PoolingRequestOutput] = encoded_output[:split]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[split:]

        # Broadcast a single query embedding against all documents.
        if len(encoded_output_1) == 1:
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score `(data_1[i], data_2[i])` pairs with a cross-encoder model.

        A single-element `data_1` is replicated to pair with every entry of
        `data_2` (the `1 -> N` case). Each pair is turned into one engine
        prompt via `get_score_prompt`.

        If `pooling_params` is None, `PoolingParams(task="score")` is used.
        The given params are verified for the "score" task in place. For
        pairs whose prompt carries `token_type_ids`, those ids are removed
        from the prompt and passed, compressed, through a per-pair clone of
        the params.

        Returns:
            One `ScoringRequestOutput` per pair.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        model_config = self.llm_engine.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        if pooling_params is None:
            pooling_params = PoolingParams(task="score")

        pooling_params.verify("score", model_config)
        pooling_params_list = list[PoolingParams]()

        # `_validate_truncation_size` receives this dict to fill in
        # (presumably with truncation settings) for `get_score_prompt`.
        tokenization_kwargs: dict[str, Any] = {}

        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        prompts = list[PromptType]()

        for q, d in zip(data_1, data_2):
            _, engine_prompt = get_score_prompt(
                model_config=model_config,
                data_1=q,
                data_2=d,
                tokenizer=tokenizer,
                tokenization_kwargs=tokenization_kwargs,
            )

            # Token type ids are moved out of the prompt into a per-request
            # copy of the pooling params so the shared params stay untouched.
            if (token_type_ids := engine_prompt.pop("token_type_ids", None)):
                params = pooling_params.clone()
                compressed = compress_token_type_ids(token_type_ids)
                params.extra_kwargs = {"compressed_token_type_ids": compressed}
                pooling_params_list.append(params)
            else:
                pooling_params_list.append(pooling_params)

            prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params_list,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[PoolingParams] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Optional prompt-truncation token limit.
                It is forwarded unchanged to the cross-encoder or embedding
                scoring helper.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, supports neither
                the "embed" nor the "classify" task, or is a cross-encoder
                whose `num_labels` is not 1. Also raised when a
                non-multi-modal model receives a `ScoreMultiModalParam`, a
                multi-modal prompt, or a prompt that cannot be reduced to a
                string.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Try converting the model using "
                             "`--convert embed` or `--convert classify`.")

        if (model_config.is_cross_encoder
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(data: Union[SingletonPrompt,
                                            Sequence[SingletonPrompt],
                                            ScoreMultiModalParam]):
                # A dict with a "content" key is a ScoreMultiModalParam.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError("ScoreMultiModalParam is not supported "
                                     f"for {model_config.architecture}")

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt):
                # Reduce a text/token prompt to plain text; token IDs are
                # decoded back to text (see the tokenizer note above).
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check rather than `assert`: asserts are stripped
                # under `python -O`, which would let a malformed prompt dict
                # reach the tokenizer.
                if not isinstance(prompt, str):
                    raise ValueError(
                        "Scoring prompts must be strings, TextPrompts or "
                        f"TokensPrompts; got {type(prompt).__name__}")
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Unwrap a ScoreMultiModalParam into its "content" list, or wrap a
        # bare string into a single-element list.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                pooling_params,
                lora_request)
1396

1397
1398
1399
1400
1401
1402
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1403
1404
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector forwarded as-is to the engine.

        Returns:
            The boolean reported by the engine's `reset_prefix_cache`.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1405

1406
1407
1408
1409
1410
1411
    def sleep(self, level: int = 1):
        """
        Suspend the engine so that it stops processing requests.

        While the engine sleeps it must not be given work; the caller is
        responsible for ensuring no requests are in flight until `wake_up` is
        invoked. The prefix cache is reset before the engine goes to sleep.

        Args:
            level: How deeply to sleep.
                Level 1 offloads the model weights, backing them up in CPU
                memory, and discards the kv cache, whose contents are lost.
                Use it when the same model will be served again after waking,
                and make sure there is enough CPU memory to hold the weights.
                Level 2 discards both the model weights and the kv cache, so
                neither is retained. Use it when a different or updated model
                will be loaded after waking and the previous weights are not
                needed; it reduces CPU memory pressure.
        """
        # Clear cached prefixes first, then hand the sleep request to the
        # engine.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1428
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Resume the engine after it was put to sleep. See the
        [sleep][vllm.LLM.sleep] method for details on sleep levels.

        Args:
            tags: Optional list restricting which memory allocations are
                restored; allowed values are `("weights", "kv_cache")`. When
                None, all memory is reallocated. Before the engine is used
                again, wake_up must have been called with every tag (or with
                None).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1441

1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects, as produced by the V1 engine's
            `get_metrics()`. Together they capture the current state of all
            aggregated metrics from Prometheus.

        Raises:
            RuntimeError: If the underlying engine is not the V1
                `LLMEngine`; this method is only available with the V1 LLM
                engine.
        """
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

        # Explicit check rather than `assert` so the V1-only restriction is
        # still enforced when Python runs with `-O`.
        if not isinstance(self.llm_engine, V1LLMEngine):
            raise RuntimeError(
                "LLM.get_metrics() is only available with the V1 LLM engine.")
        return self.llm_engine.get_metrics()

1456
1457
    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate a batch of prompts and enqueue one request per prompt.

        Args:
            prompts: A single prompt (str or dict) or a sequence of prompts.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt. Any `SamplingParams` in it
                are mutated in place to request only the final output.
            use_tqdm: If truthy, wrap the prompts in a progress bar; a
                callable is used as the tqdm factory.
            lora_request: One LoRA request shared by every prompt, a sequence
                with one entry per prompt, or None.
            priority: Optional per-prompt priorities. When None or empty,
                every request gets priority 0.

        Raises:
            ValueError: If `params`, `lora_request` or a non-empty `priority`
                is a sequence whose length differs from the number of
                prompts.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Check up front so a mismatched list fails before any request is
        # enqueued. Otherwise a short list raises IndexError halfway through
        # the batch, and a long one is silently truncated.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        model_config = self.llm_engine.model_config

        for i, prompt in enumerate(it):
            param = params[i] if isinstance(params, Sequence) else params

            tokenization_kwargs: dict[str, Any] = {}
            _validate_truncation_size(model_config.max_model_len,
                                      param.truncate_prompt_tokens,
                                      tokenization_kwargs)

            self._add_request(
                prompt,
                param,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                priority=priority[i] if priority else 0,
            )
1509

1510
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine under a fresh request ID.

        The ID is the next value drawn from `self.request_counter`, converted
        to a string. All other arguments are forwarded unchanged to
        `llm_engine.add_request`.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
1527

1528
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If truthy, show a "Processed prompts" progress bar; a
                callable is used as the tqdm factory. For `RequestOutput`s the
                bar's postfix shows estimated input/output token throughput.

        Returns:
            Every finished output, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs)
                            # tqdm can report elapsed == 0, e.g. for a bar
                            # created with disable=True through a callable
                            # `use_tqdm`. Skip the speed estimate instead of
                            # raising ZeroDivisionError.
                            elapsed = pbar.format_dict["elapsed"]
                            if elapsed:
                                in_spd = total_in_toks / elapsed
                                out_spd = total_out_toks / elapsed
                                pbar.postfix = (
                                    f"est. speed input: {in_spd:.2f} toks/s, "
                                    f"output: {out_spd:.2f} toks/s")
                            pbar.update(n)
                        else:
                            pbar.update(1)
                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))