llm.py 77.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Sequence
7
from contextlib import contextmanager
8
9
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
10

11
import cloudpickle
12
import torch.nn as nn
13
from pydantic import ValidationError
14
from tqdm.auto import tqdm
15
from typing_extensions import TypeVar, deprecated
16

17
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
18
19
                              BeamSearchSequence,
                              create_sort_beams_key_function)
20
21
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
22
23
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
24
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
25
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
26
                                         ChatTemplateContentFormatOption,
27
28
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
29
30
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
31
32
33
34
35
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
                                          get_score_prompt)
36
from vllm.entrypoints.utils import _validate_truncation_size
37
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
38
from vllm.inputs.parse import parse_and_batch_prompt
39
from vllm.logger import init_logger
40
from vllm.lora.request import LoRARequest
41
42
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
43
from vllm.model_executor.layers.quantization import QuantizationMethods
44
45
46
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
47
from vllm.pooling_params import PoolingParams, PoolingTask
48
from vllm.prompt_adapter.request import PromptAdapterRequest
49
50
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
51
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
52
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
53
from vllm.usage.usage_lib import UsageContext
54
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
55

56
57
58
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

59
60
logger = init_logger(__name__)

61
62
_R = TypeVar("_R", default=Any)

63
64

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
65
66
67
68
69
70
71
72
73
74
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
75
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
76
77
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
78
79
80
        skip_tokenizer_init: If True, skip initialization of the tokenizer and
            detokenizer. Inputs are then expected to provide valid
            `prompt_token_ids` and `None` for the prompt text.
81
82
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
83
84
85
86
        allowed_local_media_path: Allows API requests to read local images
            or videos from the directories specified on the server's file
            system. This is a security risk and should only be enabled in
            trusted environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
87
88
89
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
90
91
92
93
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
94
        quantization: The method used to quantize the model weights, e.g.
            "awq", "gptq", or "fp8"; see `QuantizationMethods` for the
            accepted values. If None, we first check the `quantization_config`
            attribute in the model config file. If that is None, we assume the
            model weights are not quantized and use `dtype` to determine the
            data type of the weights.
Jasmond L's avatar
Jasmond L committed
100
101
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
102
103
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
104
105
106
107
108
109
110
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
111
112
113
114
115
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, values that are too small may cause out-of-memory (OOM)
            errors. Note that `best_of` is only supported in V0.
116
117
118
119
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
120
121
122
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
123
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
124
            When a sequence has context length larger than this, we fall back
125
126
127
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
128
129
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
130
131
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
132
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
135
136
137
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
138
139
140
141
142
143
144
145
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        override_pooler_config: Initialize non-default pooling config or
            override default pooling config for the pooling model.
            e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
146
147
148
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
149
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
150

151
152
    Note:
        This class is intended to be used for offline inference. For online
153
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
154
    """
155

156
    # A flag to toggle whether to deprecate the legacy generate/encode API.
    # It is read lazily (via a lambda) by the `deprecate_kwargs` decorator
    # on `generate`, so changing it at runtime takes effect immediately.
    DEPRECATE_LEGACY: ClassVar[bool] = True

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Force deprecation of the legacy API inside a ``with`` block.

        Sets ``DEPRECATE_LEGACY`` to True for the duration of the block and
        then restores whatever value it had on entry. The restore happens in
        a ``finally`` clause, so it also runs when the block raises. (The
        flag is no longer left at False after exit, which used to silently
        change the class-wide default.)
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = previous

168
169
170
    def __init__(
        self,
        model: str,
        *,
        task: TaskOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """Normalize the options, build ``EngineArgs`` and create the engine.

        A few entries of ``kwargs`` are normalized before being forwarded to
        ``EngineArgs`` together with the explicit keyword arguments.

        Raises:
            ValueError: If ``kv_transfer_config`` is given as a dict that
                cannot be converted to a ``KVTransferConfig``.
        """
        # Stats logging is off by default for offline use unless the caller
        # passes `disable_log_stats` explicitly.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class passed as a type (not as a qualified-name string)
        # is serialized with cloudpickle to avoid pickling issues.
        custom_worker = kwargs.get("worker_cls")
        if isinstance(custom_worker, type):
            kwargs["worker_cls"] = cloudpickle.dumps(custom_worker)

        # Accept a plain dict for `kv_transfer_config` and validate it now so
        # the user gets a clear ValueError instead of a late failure.
        raw_kv_transfer = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer, dict):
            from vllm.config import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_kv_transfer)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer, e)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        hf_overrides = {} if hf_overrides is None else hf_overrides

        # Normalize `compilation_config` into a CompilationConfig instance:
        # an int selects the optimization level, a dict is filtered down to
        # the dataclass's init fields, anything else is used as-is.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            accepted_fields = {
                field_name: field_value
                for field_name, field_value in compilation_config.items()
                if is_init_field(CompilationConfig, field_name)
            }
            compilation_config_instance = CompilationConfig(**accepted_fields)
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by `get_default_sampling_params`.
        self.default_sampling_params: Union[dict[str, Any], None] = None
280

281
282
283
284
285
286
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer for ``lora_request``.

        Delegates to ``get_lora_tokenizer`` on the engine's tokenizer group.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
287
288

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer that does not already look cached is wrapped with
        ``get_cached_tokenizer`` before being installed.
        """
        group = self.llm_engine.get_tokenizer_group()

        # CachedTokenizer classes are generated dynamically, so the class
        # name is the only thing we can match on. A user-defined tokenizer
        # whose class name starts with 'Cached' will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))
298

299
    def get_default_sampling_params(self) -> SamplingParams:
        """Build the default ``SamplingParams`` for this model.

        The model config's ``get_diff_sampling_param()`` result is fetched
        once and memoized in ``self.default_sampling_params``. When that
        mapping is empty, a plain ``SamplingParams()`` is returned.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

307
308
309
310
311
312
313
    # Primary signature: `prompts` (positional-only) carries text and/or
    # token IDs. `priority` is listed here because the implementation
    # accepts it; omitting it made type checkers reject valid calls.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    # The overloads below describe the deprecated calling conventions that
    # pass token IDs through the separate `prompt_token_ids` argument.
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
404
405
406
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing token IDs alongside (or
                instead of) `prompts`. Deprecated; put the token IDs into
                `prompts` instead.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                is converted to a `GuidedDecodingRequest` and may contain at
                most one entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the 'generate'
                runner, or if `guided_options_request` is a dict with more
                than one entry.

        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass the token IDs via the `prompts` parameter.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            messages = [
                "LLM.generate() is only supported for generative models."
            ]

            if "generate" in model_config.supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate` or "
                    "`--task transcription`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge `prompts` and `prompt_token_ids`.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Prompt truncation is only taken from a single shared
        # SamplingParams; with a per-prompt list, `None` is passed to
        # `_validate_truncation_size`.
        tokenization_kwargs: dict[str, Any] = {}
        truncate_prompt_tokens = None
        if isinstance(sampling_params, SamplingParams):
            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(
            parsed_prompts, lora_request)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
518

519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
    def _get_modality_specific_lora_reqs(
        self,
        parsed_prompts: Union[PromptType, Sequence[PromptType]],
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
    ) -> Optional[Union[list[Optional[LoRARequest]], LoRARequest]]:
        """Fill in default modality-specific LoRAs for multimodal prompts.

        Args:
            parsed_prompts: A single prompt or a sequence of prompts.
            lora_request: `None`, one request shared by every prompt, or a
                sequence with one entry per prompt.

        Returns:
            `lora_request` unchanged when there is no LoRA config, the model
            is not multimodal, or no `default_mm_loras` are registered.
            Otherwise, a list with one optional LoRA request per prompt,
            each resolved by `_resolve_single_prompt_mm_lora`.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from the number of prompts.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or lora_config.default_mm_loras is None):
            return lora_request

        # A plain string is itself a Sequence but represents ONE prompt;
        # without this check we would iterate over its characters and
        # return one LoRA request per character.
        if isinstance(parsed_prompts, str) or not isinstance(
                parsed_prompts, Sequence):
            parsed_prompts = [parsed_prompts]

        if isinstance(lora_request, Sequence):
            # zip() below would silently truncate a mismatched list.
            if len(lora_request) != len(parsed_prompts):
                raise ValueError(
                    "Lora request list should be the same length as the "
                    "prompts")
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(parsed_prompts)

        default_mm_loras = lora_config.default_mm_loras
        return [
            self._resolve_single_prompt_mm_lora(parsed_prompt, opt_lora_req,
                                                default_mm_loras)
            for parsed_prompt, opt_lora_req in zip(parsed_prompts,
                                                   optional_loras)
        ]

    def _resolve_single_prompt_mm_lora(
        self,
        parsed_prompt: PromptType,
        lora_request: Optional[LoRARequest],
        default_mm_loras: Optional[dict[str, str]],
    ) -> Optional[LoRARequest]:
        """Choose the LoRA request to use for a single prompt.

        Args:
            parsed_prompt: The prompt; only dict prompts containing
                `multi_modal_data` are considered for a default LoRA.
            lora_request: The explicitly provided request, if any.
            default_mm_loras: Mapping from modality name to LoRA path.

        Returns:
            `lora_request` when there is no applicable default, when the
            prompt matches more than one registered modality, or when an
            explicit request was provided. Otherwise, a new `LoRARequest`
            for the single matching modality.
        """
        if (not default_mm_loras or not isinstance(parsed_prompt, dict)
                or "multi_modal_data" not in parsed_prompt):
            return lora_request

        parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt)

        matched_modalities = set(
            parsed_prompt["multi_modal_data"].keys()).intersection(
                default_mm_loras.keys())
        if not matched_modalities:
            return lora_request
        if len(matched_modalities) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", sorted(matched_modalities))
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = matched_modalities.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; warn only when its ID
        # differs from the one the default modality LoRA would have used.
        if lora_request is not None:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

596
    def collective_rpc(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> list[_R]:
        """
        Run a method on every worker and gather what each one returns.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and shipped to every worker. A callable receives
                the worker object as an extra leading `self` argument,
                followed by `args` and `kwargs`.
            timeout: Upper bound in seconds on how long to wait; a
                [`TimeoutError`][] is raised when it is exceeded. `None`
                waits indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, in a list.

        Note:
            Prefer this API for control messages only; set up a separate
            data-plane communication channel for bulk data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
625
626

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Invoke `func` directly on the model held by each worker and collect
        the value returned by every invocation.
        """
        return self.llm_engine.model_executor.apply_model(func)

    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: `None`, a single `LoRARequest` shared by every
                prompt, or a sequence holding one (optional) `LoRARequest`
                per prompt.
            prompts: The prompts the LoRA requests are paired with.

        Returns:
            A list with exactly one entry per prompt.

        Raises:
            ValueError: If a sequence is given whose length differs from the
                number of prompts.
            TypeError: If `lora_request` has any other type.
        """
        if lora_request is None or isinstance(lora_request, LoRARequest):
            # Broadcast the single (or absent) request to every prompt.
            return [lora_request] * len(prompts)

        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(prompts):
                raise ValueError("Lora request list should be the same "
                                 "length as the prompts")
            # A correctly sized sequence is paired one-to-one with the
            # prompts instead of falling through to the TypeError below.
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        use_tqdm: bool = False,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is a dict (`TextPrompt`
                or `TokensPrompt`) holding either a `"prompt"` string, which
                is tokenized here, or `"prompt_token_ids"`. It may also carry
                `"multi_modal_data"` and `"mm_processor_kwargs"`, which are
                propagated to every beam.
            params: The beam search parameters (beam width, max tokens,
                temperature, `ignore_eos`, length penalty).
            lora_request: LoRA request to use for generation, if any. It is
                resolved to one entry per prompt by
                `_get_beam_search_lora_requests`.
            use_tqdm: Whether to use tqdm to display the progress bar.

        Returns:
            One `BeamSearchOutput` per prompt, in input order. Each holds at
            most `beam_width` sequences ranked by the length-penalized sort
            key, with their decoded `text` filled in.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, prompts):
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ))

        token_iter = range(max_tokens)
        if use_tqdm:
            token_iter = tqdm(token_iter,
                              desc="Beam search",
                              unit="token",
                              unit_scale=False)
            logger.warning(
                "The progress bar shows the upper bound on token steps and "
                "may finish early due to stopping conditions. It does not "
                "reflect instance-level progress.")

        for _ in token_iter:
            # Flatten the live beams of all instances into one batch.
            # chain.from_iterable is linear, unlike sum(lists, []), which
            # re-copies the accumulated list for every instance.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            # pos[k]:pos[k + 1] is the slice of all_beams (and of the
            # generate() output) that belongs to instances[k].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            if len(all_beams) == 0:
                break

            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                lora_request=current_beam.lora_request,
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if (token_id == tokenizer.eos_token_id
                                    and not ignore_eos):
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after the last step compete with the
            # EOS-terminated ones for the final top-`beam_width` slots.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Each conversation is rendered through the chat template, tokenized,
        and the resulting token prompts are passed to the [generate][]
        method. Multi-modal inputs can be supplied in the same way as for
        the OpenAI API.

        Args:
            messages: Either a single conversation or a list of conversations.

                - A conversation is a list of messages.
                - A message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, the default sampling parameters are used. A single
                value applies to every conversation; a list must have one
                entry per conversation and is paired with them in order.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. It
                also selects the tokenizer used to render the conversations.
            chat_template: The template used to structure the chat. When not
                provided, the model's default chat template is used.
            chat_template_content_format: How message content is rendered.
                The default "auto" is resolved by
                `resolve_chat_template_content_format`.

                - "string" renders the content as a string.
                  Example: `"Who are you?"`
                - "openai" renders the content as a list of dictionaries,
                  similar to the OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, the final message of the
                conversation is continued instead of starting a new one.
                Must not be `True` together with `add_generation_prompt`.
            tools: Optional tool definitions. They are forwarded to the chat
                template and taken into account when resolving the content
                format.
            chat_template_kwargs: Extra kwargs for the chat template. They
                are applied last, so they override the template arguments
                derived from the parameters above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for
                this chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects with the generated responses,
            in the same order as the input conversations.
        """
        if is_list_of(messages, list):
            # A batch of conversations: list[list[...]].
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # A single conversation: wrap it so both cases share one path.
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
        }
        # User-supplied template kwargs take precedence over the above.
        template_kwargs.update(chat_template_kwargs or {})

        def build_prompt(
            conversation_msgs: list[ChatCompletionMessageParam],
        ) -> TokensPrompt:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already contains the special tokens,
                # so the tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)

            built = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                built["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                built["mm_processor_kwargs"] = mm_processor_kwargs
            return built

        prompts: list[Union[TokensPrompt, TextPrompt]] = [
            build_prompt(msgs) for msgs in list_of_messages
        ]

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Typing stub for the preferred form: positional-only prompts."""

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy typing stub: one text prompt plus optional token ids."""

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy typing stub: a batch of text prompts plus optional ids."""

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy typing stub: required token ids, optional text prompt."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy typing stub: batched token ids, optional text prompts."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Legacy typing stub: token ids passed as the third positional."""

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Every parameter object
                is verified against `pooling_task` before use.
            prompt_token_ids: Legacy way to pass pre-tokenized prompts,
                combined with `prompts` via `_convert_v1_inputs`. Deprecated;
                pass token prompts through `prompts` instead.
            truncate_prompt_tokens: Passed, together with the model's
                `max_model_len`, to `_validate_truncation_size` to populate
                the tokenization kwargs. Only applied when
                `tokenization_kwargs` is None.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            pooling_task: Override the pooling task to use.
            tokenization_kwargs: Extra tokenization kwargs forwarded with
                the requests. When given, they are used as-is.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the 'pooling'
                runner.

        Note:
            Passing `prompt_token_ids` is considered legacy and may be
            removed in the future. You should instead pass pre-tokenized
            prompts (e.g. `TokensPrompt`) via the `prompts` parameter.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            if "pooling" in model_config.supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge separate text/token-id inputs into prompts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        if isinstance(pooling_params, PoolingParams):
            pooling_params.verify(pooling_task, model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(pooling_task, model_config)

        # NOTE(review): when the caller supplies tokenization_kwargs,
        # truncate_prompt_tokens is not applied -- confirm this is intended.
        if tokenization_kwargs is None:
            tokenization_kwargs = dict[str, Any]()
            _validate_truncation_size(model_config.max_model_len,
                                      truncate_prompt_tokens,
                                      tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)

    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Produce one embedding vector per input prompt.

        Prompts are batched automatically with respect to memory limits;
        for best throughput, gather all prompts into a single list and make
        one call.

        Args:
            prompts: The prompts to the LLM. A sequence of prompts may be
                passed for batch inference. See
                [PromptType][vllm.inputs.PromptType] for the accepted
                format of each prompt.
            truncate_prompt_tokens: Optional truncation size, forwarded to
                `encode`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None,
                the default pooling parameters are used.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            `EmbeddingRequestOutput` objects holding the embedding vectors,
            in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        supported_tasks = self.llm_engine.model_config.supported_tasks
        if "embed" not in supported_tasks:
            raise ValueError("Embedding API is not supported by this model. "
                             "Please set `--task embed`.")

        pooled = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            pooling_task="embed",
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        truncate_prompt_tokens: Optional[int] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            truncate_prompt_tokens: Optional truncation size, forwarded to
                `encode` (same meaning as in `embed`). Defaults to None,
                i.e. no truncation is requested.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification outputs in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        model_config = self.llm_engine.model_config
        if "classify" not in model_config.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Please set `--task classify`.")

        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by the cosine similarity of their embeddings.

        Both sides are embedded together in a single `encode` call (with
        `pooling_task="embed"`) and the results are split back into the
        `text_1` part and the `text_2` part. When `text_1` holds exactly one
        entry, its embedding is repeated so it pairs with every `text_2`
        entry.

        Returns:
            One `ScoringRequestOutput` per pair.
        """
        num_queries = len(text_1)
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            pooling_task="embed",
        )

        query_embeds: list[PoolingRequestOutput] = embeddings[:num_queries]
        doc_embeds: list[PoolingRequestOutput] = embeddings[num_queries:]

        # A single query is broadcast against every document (1 -> N).
        if len(query_embeds) == 1:
            query_embeds = query_embeds * len(doc_embeds)

        similarity = _cosine_similarity(tokenizer=tokenizer,
                                        embed_1=query_embeds,
                                        embed_2=doc_embeds)

        validated = self.engine_class.validate_outputs(similarity,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by running each pair jointly through a cross-encoder.

        Every `(data_1[i], data_2[i])` pair becomes one engine prompt, and
        all prompts are submitted with `PoolingParams(task="score")`. A
        single `data_1` entry is repeated to match the length of `data_2`.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is only enabled for `--task embed or score`")

        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        pooling_params = PoolingParams(task="score")
        tokenization_kwargs: dict[str, Any] = {}

        model_config = self.llm_engine.model_config
        # `tokenization_kwargs` is handed to the validator, presumably so it
        # can record truncation settings used by the tokenizer calls below.
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        prompts_to_score = []
        if model_config.is_multimodal_model:
            for query, doc in zip(data_1, data_2):
                _, engine_prompt = get_score_prompt(
                    model_config=model_config,
                    data_1=query,
                    data_2=doc,
                    tokenizer=tokenizer,
                    tokenization_kwargs=tokenization_kwargs,
                )
                prompts_to_score.append(engine_prompt)
        else:
            for query, doc in zip(data_1, data_2):
                if model_config.use_pad_token:
                    # cross_encoder models default to using pad_token, so
                    # the two texts are tokenized as a (text, text_pair).
                    encoded = tokenizer(
                        text=query,  # type: ignore[arg-type]
                        text_pair=doc,  # type: ignore[arg-type]
                        **tokenization_kwargs)
                else:
                    # `llm as reranker` models default to not using
                    # pad_token, so the two texts are concatenated.
                    encoded = tokenizer(
                        text=query + doc,  # type: ignore[operator]
                        **tokenization_kwargs)
                prompts_to_score.append(
                    TokensPrompt(
                        prompt_token_ids=encoded["input_ids"],
                        token_type_ids=encoded.get("token_type_ids")))

        self._validate_and_add_requests(
            prompts=prompts_to_score,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        finished = self._run_engine(use_tqdm=use_tqdm)
        validated = self.engine_class.validate_outputs(finished,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

1373
1374
    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
        `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.

        For cross-encoder models each pair is scored jointly by the model;
        otherwise both sides are embedded and each pair is scored by the
        cosine similarity of its embeddings. Prompts are batched
        automatically, considering the memory constraint. For the best
        performance, put all of your inputs into a single list and pass it to
        this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Optional truncation length for the
                tokenized inputs. It is passed through to `encode` for
                embedding models and validated against the model's
                `max_model_len` for cross-encoder models.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the pooling runner,
                supports neither `embed` nor `classify`, is a classifier with
                `num_labels != 1`, or (for text-only models) receives
                multi-modal inputs.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type

        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            if "pooling" in model_config.supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if all(t not in model_config.supported_tasks
               for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Please set `--task embed` or `--task classify`.")

        if (model_config.task == "classify"
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(data: Union[SingletonPrompt,
                                            Sequence[SingletonPrompt],
                                            ScoreMultiModalParam]):
                # Reject `ScoreMultiModalParam` inputs for text-only models.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError(
                        "ScoreMultiModalParam is not supported for "
                        f"{model_config.architecture}")

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt):
                # Reduce a single prompt to plain text; token-id prompts are
                # decoded back to text with the tokenizer.
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                assert type(prompt) is str
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # For multi-modal models: unwrap `ScoreMultiModalParam` into its
        # content list and wrap a bare string into a one-element list. For
        # text-only models both inputs are already lists of str here, so
        # neither branch applies.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request,
                prompt_adapter_request)
1518

1519
1520
1521
1522
1523
1524
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1525
1526
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Forwarded unchanged to the engine's `reset_prefix_cache`.

        Returns:
            Whatever the engine's `reset_prefix_cache` returns (annotated as
            `bool`).
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1527

1528
1529
1530
1531
1532
1533
    def sleep(self, level: int = 1):
        """
        Put the engine into sleep mode; it must not process requests while
        asleep. The caller must make sure no requests are being processed for
        the whole sleep period, i.e. until `wake_up` is called.

        The prefix cache is reset before the engine is put to sleep.

        Args:
            level: The sleep level.
                Level 1 offloads the model weights and discards the kv cache,
                whose content is forgotten. The weights are backed up in CPU
                memory, so make sure there is enough CPU memory to hold them.
                Level 1 is good for sleeping and later waking up the engine
                to run the same model again.
                Level 2 discards both the model weights and the kv cache, and
                the content of both is forgotten. Level 2 is good for
                sleeping and later waking up the engine to run a different
                model or update the model, when the previous weights are not
                needed. It reduces CPU memory pressure.
        """
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1550
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Bring the engine back from sleep mode (see the [sleep][] method).

        Args:
            tags: Optional list of tags selecting which engine memory to
                reallocate; each value must be one of
                `("weights", "kv_cache")`. `None` reallocates all memory.
                Before the engine is used again, wake_up must have been
                called with all tags (or with None).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1563

1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics, as returned by the V1 engine's
            `get_metrics()`.

        Note:
            This method is only available with the V1 LLM engine; with any
            other engine the `isinstance` assertion below fails.
        """
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        assert isinstance(self.llm_engine, V1LLMEngine)
        return self.llm_engine.get_metrics()

1578
1579
    # LEGACY
    def _convert_v1_inputs(
1580
        self,
1581
1582
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
1583
1584
    ):
        # skip_tokenizer_init is now checked in engine
1585

1586
1587
1588
1589
1590
1591
1592
1593
1594
        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")
        if prompts is not None and prompt_token_ids is not None \
                and len(prompts) != len(prompt_token_ids):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

1595
1596
1597
1598
1599
1600
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]
1601
1602
        if prompts is not None:
            num_requests = len(prompts)
1603
        elif prompt_token_ids is not None:
1604
            num_requests = len(prompt_token_ids)
1605
        parsed_prompts: list[PromptType] = []
1606
        for i in range(num_requests):
1607
            item: PromptType
1608

1609
            if prompts is not None:
1610
1611
1612
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
1613
            else:
1614
                raise AssertionError
1615

1616
            parsed_prompts.append(item)
1617

1618
        return parsed_prompts
1619
1620
1621

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments and submit every prompt.

        A single prompt (`str` or `dict`) is treated as a one-element batch.
        `params` and `lora_request` may each be a single value shared by all
        prompts or a sequence with one entry per prompt; `priority`, when
        non-empty, must also have one entry per prompt.

        Sampling params are mutated in place: any legacy `guided_options`
        are merged in and `output_kind` is set to `FINAL_ONLY`.

        Raises:
            ValueError: If `params`, `lora_request` or `priority` does not
                match the number of prompts. These checks run before any
                request is added to the engine.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        # Validate every per-request sequence up front, so that a mismatch
        # is reported before any request has been handed to the engine
        # (previously a short `priority` list failed with IndexError midway
        # through the add loop, after earlier requests were enqueued).
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # An empty/None `priority` keeps the old meaning: priority 0 for all.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1677

1678
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit one prompt to the engine under a fresh request id.

        The request id is the decimal string form of the next value drawn
        from `self.request_counter`.
        """
        request_id = str(next(self.request_counter))
        engine_kwargs = dict(
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
        self.llm_engine.add_request(request_id, prompt, params,
                                    **engine_kwargs)
1697

1698
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy legacy `guided_options` into `params.guided_decoding`.

        Returns `params`, mutated in place when `guided_options` is given
        and returned untouched otherwise.

        Raises:
            ValueError: If `guided_options` is given while
                `params.guided_decoding` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )
        return params

1721
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until every pending request has finished.

        Args:
            use_tqdm: If truthy, show a "Processed prompts" progress bar. A
                callable is used as the tqdm factory instead of `tqdm`.

        Returns:
            The finished outputs, sorted by integer request id.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs)
                            # A disabled tqdm bar (e.g. one created with
                            # `disable=True`) reports 0 elapsed time; skip
                            # the speed estimate instead of dividing by 0.
                            elapsed = pbar.format_dict["elapsed"]
                            if elapsed:
                                in_spd = total_in_toks / elapsed
                                out_spd = total_out_toks / elapsed
                                pbar.postfix = (
                                    f"est. speed input: {in_spd:.2f} toks/s, "
                                    f"output: {out_spd:.2f} toks/s")
                            pbar.update(n)
                        else:
                            pbar.update(1)
                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))