"vllm/tool_parsers/kimi_k2_tool_parser.py" did not exist on "1f0c75afa95303fcb628861f040199090e82004d"
llm.py 69.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Sequence
7
from contextlib import contextmanager
8
9
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
10

11
import cloudpickle
12
import torch.nn as nn
13
from pydantic import ValidationError
14
from tqdm.auto import tqdm
15
from typing_extensions import TypeVar, deprecated
16

17
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
18
19
                              BeamSearchSequence,
                              create_sort_beams_key_function)
20
21
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
22
23
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
24
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
25
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
26
                                         ChatTemplateContentFormatOption,
27
28
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
29
30
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
31
32
from vllm.entrypoints.score_utils import (_cosine_similarity,
                                          _validate_score_input_lens)
33
from vllm.entrypoints.utils import _validate_truncation_size
34
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
35
from vllm.inputs.parse import parse_and_batch_prompt
36
from vllm.logger import init_logger
37
from vllm.lora.request import LoRARequest
38
39
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
40
from vllm.model_executor.layers.quantization import QuantizationMethods
41
42
43
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
44
from vllm.pooling_params import PoolingParams
45
from vllm.prompt_adapter.request import PromptAdapterRequest
46
47
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
48
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
49
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
50
from vllm.usage.usage_lib import UsageContext
51
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
52

53
54
55
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

56
57
# Logger shared by everything in this module.
logger = init_logger(__name__)

# Result type of the worker RPC helpers (`LLM.collective_rpc`,
# `LLM.apply_model`); falls back to `Any` when it is not pinned down.
_R = TypeVar("_R", default=Any)

60
61

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
62
63
64
65
66
67
68
69
70
71
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
72
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
73
74
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
75
76
77
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
78
79
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
80
81
82
83
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
84
85
86
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
87
88
89
90
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
91
        quantization: The method used to quantize the model weights. Currently,
92
            we support "awq", "gptq", and "fp8" (experimental).
93
94
95
96
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
97
98
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
99
100
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
101
102
103
104
105
106
107
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
108
109
110
111
112
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Noting that `best_of` is only supported in V0. Otherwise, too small
            values may cause out-of-memory (OOM) errors.
113
114
115
116
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
117
118
119
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
120
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
121
            When a sequence has context length larger than this, we fall back
122
123
124
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
125
126
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
127
128
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
129
        hf_token: The token to use as HTTP bearer authorization for remote files
130
            . If `True`, will use the token generated when running
131
            `huggingface-cli login` (stored in `~/.huggingface`).
132
133
134
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
135
136
137
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
138
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
139

140
141
    Note:
        This class is intended to be used for offline inference. For online
142
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
143
    """
144

145
    DEPRECATE_LEGACY: ClassVar[bool] = True
146
147
148
149
150
151
152
153
154
155
156
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Enable deprecation reporting for the legacy API within a block.

        Sets `DEPRECATE_LEGACY` to True for the duration of the `with` block
        and restores the value it had on entry when the block exits, whether
        it exits normally or by raising.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore instead of forcing False: the class default is True, so
            # unconditionally clearing the flag would disable deprecation
            # reporting for the rest of the process after a single use.
            cls.DEPRECATE_LEGACY = previous

157
158
159
    def __init__(
        self,
        model: str,
        *,
        task: TaskOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Normalizes a few loosely-typed arguments, builds an `EngineArgs` from
        all of them, and creates the engine. See the class docstring for the
        meaning of each argument.

        Raises:
            ValueError: If `kv_transfer_config` is passed as a dict that does
                not validate as a `KVTransferConfig`.
        """
        # Stats logging is off for this class unless explicitly requested.
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if "kv_transfer_config" in kwargs and isinstance(
                kwargs["kv_transfer_config"], dict):
            from vllm.config import KVTransferConfig
            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict, e)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            # An int is shorthand for the compilation level.
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Only init fields can be passed to the constructor; anything
            # else is dropped, but the user is told so that typos in config
            # keys do not go unnoticed.
            init_kwargs = {
                key: value
                for key, value in compilation_config.items()
                if is_init_field(CompilationConfig, key)
            }
            ignored_keys = [
                key for key in compilation_config if key not in init_kwargs
            ]
            if ignored_keys:
                logger.warning(
                    "Ignoring 'compilation_config' keys that are not init "
                    "fields of CompilationConfig: %s", ignored_keys)
            compilation_config_instance = CompilationConfig(**init_kwargs)
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

270
271
272
273
274
275
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer that the engine's tokenizer group associates
        with `lora_request`."""
        group = self.llm_engine.get_tokenizer_group()
        return group.get_lora_tokenizer(lora_request)
276
277

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the engine's tokenizer group.

        A tokenizer that is not already a cached wrapper is wrapped with
        `get_cached_tokenizer` before being installed.
        """
        group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are created dynamically, so the class-name
        # prefix is the only available check. A user-defined tokenizer class
        # whose name starts with "Cached" will be misclassified.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))
287

288
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the default `SamplingParams` for this model.

        The model config's sampling-parameter diff is fetched on first use
        and cached in `self.default_sampling_params`. A non-empty diff is
        applied through `SamplingParams.from_optional`; otherwise a plain
        `SamplingParams()` is returned.
        """
        diff = self.default_sampling_params
        if diff is None:
            diff = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = diff

        if not diff:
            return SamplingParams()
        return SamplingParams.from_optional(**diff)

296
297
298
299
300
301
302
    # Preferred form: prompts already expressed as PromptType objects.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: single (prompt + optional token ids)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: multi (prompt + optional token ids)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: single (token ids + optional prompt)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: multi (token ids + optional prompt)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: single or multi token ids [pos-only]
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
393
394
395
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy, deprecated. Token IDs passed separately
                from `prompts`; pass them inside `prompts` instead (e.g. as a
                `TokensPrompt`).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                is converted into a `GuidedDecodingRequest` and may specify
                at most one guided decoding option.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the "generate" or
                "transcription" runner, or if `guided_options_request` is a
                dict with more than one entry.

        Note:
            Passing token IDs through the `prompt_token_ids` keyword is
            legacy and deprecated. You should instead pass them via the
            `prompts` parameter.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type not in ["generate", "transcription"]:
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call style: token IDs were given separately from
            # `prompts`, so both are handed to the legacy converter.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Only a single SamplingParams supplies `truncate_prompt_tokens`; for
        # a per-prompt sequence, None is passed to the validation below.
        tokenization_kwargs: dict[str, Any] = {}
        truncate_prompt_tokens = None
        if isinstance(sampling_params, SamplingParams):
            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
504

505
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Execute an RPC call on all workers.

        Args:
            method: Either the name of a worker method to run, or a callable
                that is serialized and shipped to every worker.

                A callable receives the worker object as an extra leading
                `self` argument, followed by `args` and `kwargs`.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker.

        Note:
            Prefer using this API for control messages only; set up a
            separate data-plane channel for bulk data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
534
535

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model held by each worker and collect the
        per-worker results, as returned by the model executor's
        `apply_model`.
        """
        return self.llm_engine.model_executor.apply_model(func)
542

543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: None, a single `LoRARequest` shared by every
                prompt, or a sequence holding one entry per prompt.
            prompts: The beam search prompts.

        Returns:
            A list with one (possibly None) LoRA request per prompt.

        Raises:
            ValueError: If a sequence is given whose length differs from the
                number of prompts.
            TypeError: If `lora_request` has any other type.
        """
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        # Per-prompt sequence whose length was validated above. Strings are
        # sequences too, but never a valid value here.
        if isinstance(lora_request, Sequence) and not isinstance(
                lora_request, (str, bytes)):
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

559
560
    def beam_search(
        self,
561
        prompts: list[Union[TokensPrompt, TextPrompt]],
562
        params: BeamSearchParams,
563
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
564
        use_tqdm: bool = False,
565
    ) -> list[BeamSearchOutput]:
566
567
568
569
570
571
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
572
            params: The beam search parameters.
573
            lora_request: LoRA request to use for generation, if any.
574
            use_tqdm: Whether to use tqdm to display the progress bar.
575
        """
576
577
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
578
579
580
581
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
582
583
        length_penalty = params.length_penalty

584
585
586
        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

587
588
589
590
591
        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )
592

593
594
595
596
597
598
599
600
601
602
603
604
        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)
605

606
607
608
609
610
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
611
                                            temperature=temperature)
612
        instances: list[BeamSearchInstance] = []
613

614
        for lora_req, prompt in zip(lora_requests, prompts):
615
616
617
618
619
620
621
622
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

623
624
            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
625
626
627
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
628

629
            instances.append(
630
631
632
633
634
635
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )
636

637
638
639
640
641
642
643
644
645
646
647
648
        token_iter = range(max_tokens)
        if use_tqdm:
            token_iter = tqdm(token_iter,
                              desc="Beam search",
                              unit="token",
                              unit_scale=False)
            logger.warning(
                "The progress bar shows the upper bound on token steps and "
                "may finish early due to stopping conditions. It does not "
                "reflect instance-level progress.")

        for _ in token_iter:
649
            all_beams: list[BeamSearchSequence] = list(
650
651
652
653
                sum((instance.beams for instance in instances), []))
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
654
            instance_start_and_end: list[tuple[int, int]] = list(
655
656
657
658
659
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

660
661
662
663
            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])
664
665
666
667
668

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
669
670
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
687
                                logprobs=current_beam.logprobs + [logprobs],
688
                                lora_request=current_beam.lora_request,
689
                                cum_logprob=current_beam.cum_logprob +
690
691
692
693
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)
694
695
696
697
698
699
700

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
701
                                      key=sort_beams_key,
702
703
704
705
706
707
708
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
709
                                      key=sort_beams_key,
710
711
712
713
714
715
716
717
718
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
719
720
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Every conversation is rendered through the chat template, turned into
        token IDs and then handed to the [generate][] method, which produces
        the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: Either a single conversation or a list of conversations.

                - A conversation is a list of messages.
                - A message is a dictionary with 'role' and 'content' keys.

            sampling_params: Sampling parameters for text generation. When
                None, the default sampling parameters are used. A single
                value applies to every prompt; a list must match the number
                of prompts and is paired with them one by one.
            use_tqdm: `True` shows a tqdm progress bar, a callable (e.g.,
                `functools.partial(tqdm, leave=False)`) is used to build the
                progress bar, and `False` disables it.
            lora_request: Optional LoRA request used for generation; it also
                selects the tokenizer.
            chat_template: Template used to structure the chat. When omitted,
                the model's default chat template is used.
            chat_template_content_format: How message content is rendered.

                - "string" renders the content as a plain string.
                  Example: `"Who are you?"`
                - "openai" renders the content as a list of dictionaries,
                  similar to the OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: Forwarded to the chat template as
                `add_generation_prompt`.
            continue_final_message: If True, the final message of the
                conversation is continued rather than a new one started.
                Must not be `True` together with `add_generation_prompt`.
            tools: Optional tool definitions. They are used when resolving the
                content format and are forwarded to the chat template.
            chat_template_kwargs: Extra keyword arguments for the chat
                template. On key collisions they override the explicit
                template arguments above.
            mm_processor_kwargs: Multimodal processor kwarg overrides, attached
                to every conversation's prompt. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects, one per conversation, in the
            same order as the input messages.
        """
        # A batch is a list of lists; anything else is a single conversation.
        if is_list_of(messages, list):
            conversations = cast(list[list[ChatCompletionMessageParam]],
                                 messages)
        else:
            conversations = [cast(list[ChatCompletionMessageParam], messages)]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
        }
        # User-provided template kwargs win over the explicit arguments.
        if chat_template_kwargs:
            template_kwargs.update(chat_template_kwargs)

        engine_prompts: list[Union[TokensPrompt, TextPrompt]] = []
        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=resolved_format_placeholder
                if False else content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                # Mistral tokenizers render the raw messages to token IDs.
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already inserts the special tokens, so
                # the tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            engine_prompts.append(engine_prompt)

        return self.generate(
            engine_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

864
865
866
867
868
869
870
    # Preferred signature: `prompts` is positional-only (the `/`), and every
    # argument after `pooling_params` is keyword-only (the `*`).
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # The overloads below describe the legacy calling convention that accepts
    # a separate `prompt_token_ids` argument. Each one carries `@deprecated`
    # so that type checkers can flag calls that match it.
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # Token IDs are required here and must be passed by keyword; the text
    # prompt is optional.
    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # Fully positional legacy form: `encode(None, None, token_ids)`.
    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
955
956
957
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Parameters that are passed
                in are verified against the model config.
            prompt_token_ids: Legacy way of passing pre-tokenized prompts.
                Deprecated; pass a `TokensPrompt` through `prompts` instead.
            truncate_prompt_tokens: Optional prompt truncation length. It is
                checked against the model's `max_model_len` by
                `_validate_truncation_size`, which also fills in the
                tokenization kwargs used for these requests.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the pooling runner.

        Note:
            Passing `prompt_token_ids` is the legacy calling convention and
            may be removed in the future. Pass every prompt (for example a
            `TokensPrompt` for pre-tokenized input) through the `prompts`
            parameter instead.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            # Point the user at the fix when the model could run as a pooler.
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge text prompts and token IDs into prompt dicts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()
        elif isinstance(pooling_params, PoolingParams):
            pooling_params.verify(self.llm_engine.model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(self.llm_engine.model_config)

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
1051

1052
1053
1054
1055
1056
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Produce one embedding vector per prompt.

        This is a thin layer over `encode`: it requires the `embed` task and
        wraps every pooled output as an `EmbeddingRequestOutput`. Prompts are
        batched automatically with the memory constraint in mind, so for the
        best performance pass all prompts together in a single list.

        Args:
            prompts: The prompts to the LLM. A sequence of prompts enables
                batch inference. See [PromptType][vllm.inputs.PromptType]
                for the accepted format of each prompt.
            truncate_prompt_tokens: Optional prompt truncation length,
                forwarded unchanged to `encode`.
            use_tqdm: `True` shows a tqdm progress bar, a callable (e.g.,
                `functools.partial(tqdm, leave=False)`) is used to build the
                progress bar, and `False` disables it.
            pooling_params: Pooling parameters. When None, the default
                pooling parameters are used.
            lora_request: Optional LoRA request used for generation.
            prompt_adapter_request: Optional Prompt Adapter request used for
                generation.

        Returns:
            `EmbeddingRequestOutput` objects holding the embedding vectors,
            ordered like the input prompts.

        Raises:
            ValueError: If the model was not configured with `--task embed`.
        """
        configured_task = self.llm_engine.model_config.task
        if configured_task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled_outputs = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled_outputs))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        truncate_prompt_tokens: Optional[int] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            truncate_prompt_tokens: Optional prompt truncation length,
                forwarded to `encode` (as in `embed`). The default of None
                leaves prompts untruncated by this argument.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model was not configured with
                `--task classify`.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            truncate_prompt_tokens=truncate_prompt_tokens,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

1145
1146
1147
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their pooled embeddings.

        Both sides are embedded in a single `encode` batch and split back
        apart afterwards. When `text_1` yields one embedding, that embedding
        is paired with every embedding from `text_2`.

        Returns:
            One `ScoringRequestOutput` per scored pair.
        """
        split_at = len(text_1)
        # Embed both sides together so they share one engine run.
        pooled: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)
        left: list[PoolingRequestOutput] = pooled[:split_at]
        right: list[PoolingRequestOutput] = pooled[split_at:]

        # 1 -> N: reuse the lone left embedding for every right embedding.
        if len(left) == 1:
            left = left * len(right)

        similarities = _cosine_similarity(tokenizer=tokenizer,
                                          embed_1=left,
                                          embed_2=right)

        validated = self.engine_class.validate_outputs(similarities,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[str],
        text_2: list[str],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score text pairs by running each pair jointly through the model.

        Each `(text_1[i], text_2[i])` pair is tokenized together (as `text`
        and `text_pair`) and submitted as one pooling request with default
        `PoolingParams`. When `text_1` has a single entry, it is paired with
        every entry of `text_2`.

        Args:
            tokenizer: Tokenizer used to encode each pair.
            text_1: First texts of the pairs (or a single text to broadcast).
            text_2: Second texts of the pairs.
            truncate_prompt_tokens: Optional truncation length. It is checked
                against `max_model_len` by `_validate_truncation_size`, and
                the resulting kwargs are passed to the tokenizer call.
            use_tqdm: Progress bar setting forwarded to the engine.
            lora_request: LoRA request to use, if any.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            One `ScoringRequestOutput` per pair, in input order.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`, which this
                code path does not support.
        """
        # The check is on the tokenizer type, so the error names the
        # tokenizer instead of the `--task` setting.
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # 1 -> N: repeat the single first text for every second text.
        if len(text_1) == 1:
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))

        pooling_params = PoolingParams()

        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(self.llm_engine.model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            # token_type_ids is None when the tokenizer does not produce it.
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

1230
1231
1232
1233
1234
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `text_1` sentence will be replicated `N`
        times to pair with the `text_2` sentences.
        The input pairs are used to build a list of prompts for the
        cross encoder model; models that are not cross encoders are scored
        through the embedding path instead. This class automatically batches
        the prompts, considering the memory constraint. For the best
        performance, put all of your texts into a single list and pass it to
        this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the `text_2` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See [PromptType][vllm.inputs.PromptType] for
                more details about the format of each prompt.
            truncate_prompt_tokens: Forwarded unchanged to the cross-encoder
                or embedding scoring helper; presumably the maximum number
                of prompt tokens to keep (confirm against the helper).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the pooling
                runner, its task is neither `embed` nor `score`, or a prompt
                carries `multi_modal_data`.
            TypeError: If a prompt cannot be reduced to a plain string.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if self.llm_engine.model_config.task not in ("embed", "score"):
            raise ValueError(
                "Score API is only enabled for `--task embed or --task score`")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        def ensure_str(prompt: SingletonPrompt) -> str:
            # Reduce any supported prompt form to plain text: token-id
            # prompts are decoded, TextPrompt dicts are unwrapped.
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for scoring")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # An explicit check instead of `assert`, so malformed prompts are
            # still rejected when Python runs with -O.
            if not isinstance(prompt, str):
                raise TypeError(
                    "Scoring prompts must be strings, TextPrompt or "
                    f"TokensPrompt dicts; got {type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        input_text_1: list[str] = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        input_text_2: list[str] = [ensure_str(t) for t in text_2]

        _validate_score_input_lens(input_text_1, input_text_2)

        if self.llm_engine.model_config.is_cross_encoder:
            return self._cross_encoding_score(tokenizer, input_text_1,
                                              input_text_2,
                                              truncate_prompt_tokens, use_tqdm,
                                              lora_request,
                                              prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                input_text_1,  # type: ignore[arg-type]
                input_text_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1333

1334
1335
1336
1337
1338
1339
    def start_profile(self) -> None:
        """Ask the wrapped engine to begin profiling.

        This is a thin pass-through to ``self.llm_engine.start_profile()``.
        """
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the wrapped engine to stop profiling.

        This is a thin pass-through to ``self.llm_engine.stop_profile()``.
        """
        engine = self.llm_engine
        engine.stop_profile()

1340
1341
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, passed through unchanged to
                the engine's ``reset_prefix_cache``.

        Returns:
            The value returned by the engine's ``reset_prefix_cache``.
        """
        engine = self.llm_engine
        outcome = engine.reset_prefix_cache(device)
        return outcome
1342

1343
1344
1345
1346
1347
1348
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.
        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        The prefix cache is reset before the engine is put to sleep.

        Args:
            level: The sleep level. Level 1 sleep will offload the model
                weights and discard the kv cache. The content of kv cache
                is forgotten. Level 1 sleep is good for sleeping and waking
                up the engine to run the same model again. The model weights
                are backed up in CPU memory. Please make sure there's enough
                CPU memory to store the model weights. Level 2 sleep will
                discard both the model weights and the kv cache. The content
                of both the model weights and kv cache is forgotten. Level 2
                sleep is good for sleeping and waking up the engine to run a
                different model or update the model, where previous model
                weights are not needed. It reduces CPU memory pressure.
        """
        # Drop cached prefixes first; presumably they would otherwise refer
        # to KV-cache contents the engine discards while asleep.
        self.reset_prefix_cache()
        self.llm_engine.sleep(level=level)

1365
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake up the engine from sleep mode. See the [sleep][] method
        for more details.

        Args:
            tags: An optional list of tags to reallocate the engine memory
                for specific memory allocations. Values must be in
                `("weights", "kv_cache")`. If None, all memory is reallocated.
                wake_up should be called with all tags (or None) before the
                engine is used again.
        """
        self.llm_engine.wake_up(tags)
1378

1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects capturing the current state of
            all aggregated metrics, exactly as returned by the V1 engine's
            ``get_metrics()``.

        Note:
            This method is only available with the V1 LLM engine; with any
            other engine the ``isinstance`` assertion below fails.
        """
        # Imported lazily, inside the method; presumably to avoid a
        # module-level dependency on the V1 engine — confirm.
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        assert isinstance(self.llm_engine, V1LLMEngine)
        return self.llm_engine.get_metrics()

1393
1394
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ) -> list[PromptType]:
        """Convert legacy ``prompts`` / ``prompt_token_ids`` arguments into
        a list of prompt objects.

        Args:
            prompts: A single text prompt or a list of text prompts.
            prompt_token_ids: A single list of token ids or a list of them.

        Returns:
            One ``TextPrompt`` per text prompt when ``prompts`` is given;
            otherwise one ``TokensPrompt`` per token-id list. When both are
            given, only ``prompts`` is used to build the items; the token ids
            only take part in the length check.

        Raises:
            ValueError: If neither argument is given, or if both are given
                with a different number of prompts.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Batch single prompts first so the length check below compares
        # numbers of prompts. Checking before batching compared the
        # character count of a lone string with the number of token ids.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if prompts is not None and prompt_token_ids is not None \
                and len(prompts) != len(prompt_token_ids):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        parsed_prompts: list[PromptType]
        if prompts is not None:
            parsed_prompts = [TextPrompt(prompt=text) for text in prompts]
        else:
            # Guaranteed by the "either must be provided" check above.
            assert prompt_token_ids is not None
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=token_ids)
                for token_ids in prompt_token_ids
            ]

        return parsed_prompts
1434
1435
1436

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments, then add one engine request per
        prompt.

        Args:
            prompts: A single prompt (a ``str`` or a ``dict``) or a sequence
                of prompts.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt. Every ``SamplingParams``
                is mutated in place: deprecated ``guided_options`` are
                merged in and ``output_kind`` is set to ``FINAL_ONLY``.
            use_tqdm: If truthy, wraps the submission loop in a progress
                bar; a callable is used as the tqdm factory.
            lora_request: One LoRA request shared by every prompt, or a
                sequence with exactly one entry per prompt.
            prompt_adapter_request: Applied to every request.
            tokenization_kwargs: Forwarded to every request.
            guided_options: Deprecated; merged into each ``SamplingParams``.
            priority: Optional per-request priorities, one per prompt. When
                ``None`` or empty, every request gets priority 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or ``priority`` is a
                per-request sequence whose length differs from the number of
                prompts, or via ``_add_guided_params``.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Validate up front. A short list used to raise IndexError only after
        # some requests had already been submitted to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1492

1493
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine under a fresh request id.

        The id is the next value of ``self.request_counter``, converted to a
        string. All other arguments are forwarded unchanged to
        ``self.llm_engine.add_request``.
        """
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1512

1513
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None
    ) -> SamplingParams:
        """Merge deprecated ``guided_options`` into ``params`` in place.

        Args:
            params: Sampling parameters to update. Mutated in place.
            guided_options: Legacy guided-decoding request. When ``None``,
                ``params`` is returned untouched.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``params.guided_decoding`` is already set while
                ``guided_options`` is also given.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern,
            structural_tag=guided_options.structural_tag,
        )
        return params

1536
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If truthy, shows a progress bar whose total is the
                number of unfinished requests at start; a callable is used
                as the tqdm factory.

        Returns:
            All finished outputs, sorted by ``int(request_id)``.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs)
                            elapsed = pbar.format_dict["elapsed"]
                            # A zero elapsed time (e.g. a coarse clock on a
                            # very fast first step) would otherwise raise
                            # ZeroDivisionError; keep the previous postfix.
                            if elapsed > 0:
                                in_spd = total_in_toks / elapsed
                                out_spd = total_out_toks / elapsed
                                pbar.postfix = (
                                    f"est. speed input: {in_spd:.2f} toks/s, "
                                    f"output: {out_spd:.2f} toks/s")
                            pbar.update(n)
                        else:
                            pbar.update(1)

                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))