"vllm/vscode:/vscode.git/clone" did not exist on "2ae25f79cf1e8d21f7bcba097e4c039463c22be4"
llm.py 75 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Sequence
7
from contextlib import contextmanager
8
9
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
10

11
import cloudpickle
12
import torch.nn as nn
13
from pydantic import ValidationError
14
from tqdm.auto import tqdm
15
from typing_extensions import TypeVar, deprecated
16

17
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
18
19
                              BeamSearchSequence,
                              create_sort_beams_key_function)
20
21
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
22
23
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
24
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
25
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
26
                                         ChatTemplateContentFormatOption,
27
28
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
29
30
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
31
32
33
34
35
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
                                          get_score_prompt)
36
from vllm.entrypoints.utils import _validate_truncation_size
37
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
38
from vllm.inputs.parse import parse_and_batch_prompt
39
from vllm.logger import init_logger
40
from vllm.lora.request import LoRARequest
41
42
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
43
from vllm.model_executor.layers.quantization import QuantizationMethods
44
45
46
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
47
from vllm.pooling_params import PoolingParams, PoolingTask
48
49
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
50
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
51
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
52
from vllm.usage.usage_lib import UsageContext
53
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
54

55
56
57
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

58
59
logger = init_logger(__name__)

60
61
_R = TypeVar("_R", default=Any)

62
63

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
64
65
66
67
68
69
70
71
72
73
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
74
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
75
76
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
77
78
79
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
80
81
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
82
83
84
85
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
86
87
88
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
89
90
91
92
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
93
        quantization: The method used to quantize the model weights. See
            `QuantizationMethods` for the supported values.
95
96
97
98
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
99
100
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
101
102
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
103
104
105
106
107
108
109
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
110
111
112
113
114
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
115
116
117
118
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
119
120
121
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
122
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
123
            When a sequence has context length larger than this, we fall back
124
125
126
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
127
128
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
129
130
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
131
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
134
135
136
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
137
138
139
140
141
142
143
144
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        override_pooler_config: Initialize non-default pooling config or
            override default pooling config for the pooling model.
            e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
145
146
147
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig` instance. If it is an integer, it is used as
            the level of compilation optimization. If it is a dictionary, it
            can specify the full compilation configuration; keys that are not
            `CompilationConfig` init fields are ignored.
148
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
149

150
151
    Note:
        This class is intended to be used for offline inference. For online
152
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
153
    """
154

155
    DEPRECATE_LEGACY: ClassVar[bool] = True
156
157
158
159
160
161
162
163
164
165
166
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

167
168
169
    def __init__(
        self,
        model: str,
        *,
        task: TaskOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """Build the `EngineArgs` and create the engine backing this `LLM`.

        The keyword arguments are documented on the class. Any additional
        `**kwargs` are forwarded to [`EngineArgs`][vllm.EngineArgs].

        Raises:
            ValueError: If `kv_transfer_config` is passed as a dict that
                cannot be validated into a `KVTransferConfig`.
        """
        # Stats logging is off by default for this class unless the caller
        # explicitly passes `disable_log_stats`.
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if isinstance(kwargs.get("kv_transfer_config"), dict):
            from vllm.config import KVTransferConfig
            raw_config_dict = kwargs["kv_transfer_config"]
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_config_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_config_dict, e)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        if hf_overrides is None:
            hf_overrides = {}

        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            # An int is shorthand for the compilation optimization level.
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Keys that are not init fields of CompilationConfig are dropped
            # instead of being passed to its constructor.
            compilation_config_instance = CompilationConfig(
                **{
                    key: value
                    for key, value in compilation_config.items()
                    if is_init_field(CompilationConfig, key)
                })
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Create the Engine (autoselects V0 vs V1)
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Filled lazily by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None
279

280
281
282
283
284
285
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer associated with `lora_request`.

        The lookup is delegated to the engine's tokenizer group via
        `get_lora_tokenizer`.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
286
287

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the engine's tokenizer group.

        A tokenizer that does not look cached is first wrapped with
        `get_cached_tokenizer`.
        """
        group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are generated dynamically, so the class
        # name prefix is the only available signal. A user-defined tokenizer
        # class whose name starts with "Cached" is misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))
297

298
    def get_default_sampling_params(self) -> SamplingParams:
        """Return sampling params built from the model's default overrides.

        The overrides come from `model_config.get_diff_sampling_param()` and
        are cached on `self.default_sampling_params` after the first fetch.
        If they are empty, a plain `SamplingParams()` is returned.
        """
        if self.default_sampling_params is None:
            # First call: fetch and cache the model-config defaults.
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

306
307
308
309
310
311
312
    # Typing overloads for `generate`. The first describes the current API;
    # the rest describe the deprecated `prompt_token_ids` calling forms.
    # Every overload declares `priority`, which the runtime implementation
    # accepts, so type checkers do not reject calls that pass it.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
397
398
399
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Deprecated. Token IDs of the prompt(s); pass
                them through `prompts` instead (e.g. as a
                [TokensPrompt][vllm.inputs.TokensPrompt]).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                is converted into a `GuidedDecodingRequest` and may specify
                only one kind of guided decoding.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the `generate`
                runner, or if `guided_options_request` is a dict with more
                than one entry.

        Note:
            Passing `prompt_token_ids` is considered legacy and is deprecated.
            You should instead pass token IDs via the `prompts` parameter.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            messages = [
                "LLM.generate() is only supported for generative models."
            ]

            if "generate" in model_config.supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate` or "
                    "`--task transcription`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge `prompts` and `prompt_token_ids`.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # `truncate_prompt_tokens` is read only from a single SamplingParams;
        # when a per-prompt list is given it stays None here.
        tokenization_kwargs: dict[str, Any] = {}
        truncate_prompt_tokens = None
        if isinstance(sampling_params, SamplingParams):
            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(
            parsed_prompts, lora_request)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            guided_options=guided_options_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
507

508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
    def _get_modality_specific_lora_reqs(
            self, parsed_prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        """Resolve per-prompt LoRA requests using `default_mm_loras`.

        In the following cases `lora_request` is returned unchanged:
        LoRA is not configured, no `default_mm_loras` are registered, or the
        model is not multimodal. Otherwise the method returns a list with
        one (optional) `LoRARequest` per prompt, each resolved by
        `_resolve_single_prompt_mm_lora`.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from the number of prompts.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or lora_config.default_mm_loras is None):
            return lora_request

        # Normalize a single prompt into a one-element list. `str` is itself
        # a `Sequence`, so it must be excluded explicitly. Otherwise a single
        # text prompt would be iterated per character, producing one LoRA
        # entry per character instead of one per prompt.
        if isinstance(parsed_prompts, str) or not isinstance(
                parsed_prompts, Sequence):
            parsed_prompts = [parsed_prompts]

        if isinstance(lora_request, Sequence):
            # zip() below would silently truncate a mis-sized list.
            if len(lora_request) != len(parsed_prompts):
                raise ValueError(
                    "Lora request list should be the same length as the "
                    "prompts")
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(parsed_prompts)

        default_mm_loras = lora_config.default_mm_loras
        return [
            self._resolve_single_prompt_mm_lora(
                parsed_prompt,
                opt_lora_req,
                default_mm_loras,
            ) for parsed_prompt, opt_lora_req in zip(parsed_prompts,
                                                     optional_loras)
        ]

    def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        """Choose the LoRA request for one prompt.

        A truthy `lora_request` passed by the caller always wins. Without
        one, a `LoRARequest` for a default modality LoRA is built when the
        prompt's `multi_modal_data` matches exactly one registered modality.
        In every other case `lora_request` is returned as-is.
        """
        no_mm_data = (not isinstance(parsed_prompt, dict)
                      or "multi_modal_data" not in parsed_prompt)
        if not default_mm_loras or no_mm_data:
            return lora_request

        prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt)
        matched = set(prompt["multi_modal_data"].keys()).intersection(
            default_mm_loras.keys())

        if not matched:
            return lora_request

        if len(matched) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities; "
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", matched)
            return lora_request

        # Exactly one modality matched. Its default LoRA ID is its 1-based
        # position among the registered modality names sorted alphabetically.
        (modality,) = matched
        lora_path = default_mm_loras[modality]
        lora_id = sorted(default_mm_loras).index(modality) + 1

        if not lora_request:
            return LoRARequest(
                modality,
                lora_id,
                lora_path,
            )

        # An explicit request is always kept; only warn when its ID differs.
        if lora_request.lora_int_id != lora_id:
            logger.warning(
                "A modality with a registered lora and a lora_request "
                "with a different ID were provided; falling back to the "
                "lora_request as we only apply one LoRARequest per prompt")
        return lora_request

585
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Invoke `method` on every worker and collect what each one returns.

        Args:
            method: The name of a worker method, or a callable that is
                serialized and sent to every worker. A callable receives
                the worker object as an extra leading `self` argument,
                followed by `args` and `kwargs`.
            timeout: Upper bound, in seconds, on how long to wait. A
                [`TimeoutError`][] is raised when it is exceeded; `None`
                waits indefinitely.
            args: Positional arguments for the worker method.
            kwargs: Keyword arguments for the worker method.

        Returns:
            One result per worker, collected in a list.

        Note:
            Use this API for control messages only; set up a separate
            data-plane channel for passing bulk data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
614
615

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Call `func` on the model held by each worker and collect the return
        values, one entry per worker.
        """
        return self.llm_engine.model_executor.apply_model(func)

    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional lora request corresponding to each prompt.

        Args:
            lora_request: `None`, a single `LoRARequest` shared by every
                prompt, or a sequence with exactly one entry per prompt.
            prompts: The prompts the LoRA requests are paired with.

        Returns:
            A list with one (possibly `None`) entry per prompt.

        Raises:
            ValueError: If a sequence is given whose length differs from
                the number of prompts.
            TypeError: If `lora_request` is of any other type.
        """
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        if isinstance(lora_request, Sequence):
            # A correctly sized per-prompt list is valid input; without this
            # branch it fell through to the TypeError below.
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        use_tqdm: bool = False,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is a `TokensPrompt`
                (carrying `prompt_token_ids`) or a `TextPrompt` (carrying
                `prompt`, which is tokenized here). Optional
                `multi_modal_data` / `mm_processor_kwargs` entries are
                carried along with every beam of that prompt.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any. The
                accepted forms are those handled by
                `_get_beam_search_lora_requests`.
            use_tqdm: Whether to use tqdm to display the progress bar.

        Returns:
            One `BeamSearchOutput` per prompt, in input order. Each holds at
            most `beam_width` sequences, sorted by descending beam score
            (length penalty applied), with `text` decoded from the beam's
            tokens.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

        tokenizer = self.get_tokenizer()
        # Read once: it is compared against every candidate token below.
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(
            eos_token_id,
            length_penalty,
        )

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, prompts):
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )

        token_iter = range(max_tokens)
        if use_tqdm:
            token_iter = tqdm(token_iter,
                              desc="Beam search",
                              unit="token",
                              unit_scale=False)
            logger.warning(
                "The progress bar shows the upper bound on token steps and "
                "may finish early due to stopping conditions. It does not "
                "reflect instance-level progress.")

        for _ in token_iter:
            # Flatten the live beams of every instance into one batch.
            # chain.from_iterable is linear, whereas sum(..., []) re-copied
            # the accumulated list once per instance (quadratic).
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(instance.beams
                                              for instance in instances))
            if not all_beams:
                break

            # instance_start_and_end[k] is the [start, end) slice of
            # all_beams (and of the generate() outputs) owned by instance k.
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                lora_request=current_beam.lora_request,
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive when the step budget ran out compete with
            # the ones that already hit EOS.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Each conversation is rendered with the chat template, tokenized, and
        then passed to the [generate][] method. Multi-modal content can be
        supplied the same way it is for the OpenAI API.

        Args:
            messages: Either a single conversation or a list of
                conversations.

                - A conversation is a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. A single
                value is applied to every prompt; a list must have the same
                length as the prompts and is paired with them one by one.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. It also
                selects the tokenizer used to render the conversations.
            chat_template: The template used to structure the chat. When not
                provided, the model's default chat template is used.
            chat_template_content_format: How message content is rendered.

                - "string" renders the content as a string.
                  Example: `"Who are you?"`
                - "openai" renders the content as a list of dictionaries,
                  similar to the OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, the chat template is asked to add
                a generation prompt when rendering each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are used when resolving
                the content format and are forwarded to the chat template.
            chat_template_kwargs: Additional kwargs for the chat template.
                On a key collision they override the options above.
            mm_processor_kwargs: Multimodal processor kwarg overrides attached
                to every prompt of this chat request. Only used for offline
                requests.

        Returns:
            A list of `RequestOutput` objects holding the generated responses,
            in the same order as the input conversations.
        """
        # Wrap a lone conversation so both input shapes share one code path.
        conversations: list[list[ChatCompletionMessageParam]]
        if not is_list_of(messages, list):
            conversations = [cast(list[ChatCompletionMessageParam], messages)]
        else:
            conversations = cast(list[list[ChatCompletionMessageParam]],
                                 messages)

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Caller-provided template kwargs win over the explicit options.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        def _render_conversation(
                convo: list[ChatCompletionMessageParam]) -> TokensPrompt:
            # Turn one conversation into a token prompt plus optional
            # multimodal attachments.
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            parsed_conversation, mm_data = parse_chat_messages(
                convo,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=convo,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=parsed_conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already emits the special tokens, so the
                # tokenizer must not add them a second time.
                token_ids = tokenizer.encode(rendered, add_special_tokens=False)

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            return engine_prompt

        prompts: list[Union[TokensPrompt, TextPrompt]] = [
            _render_conversation(convo) for convo in conversations
        ]

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    # Preferred form: the prompt(s) are passed positionally.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[
            PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # LEGACY: one text prompt, optionally with its token ids.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[
            PoolingParams, Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # LEGACY: several text prompts, optionally with their token ids.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[
            PoolingParams, Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # LEGACY: one token-id prompt (keyword), optionally with its text.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[
            PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # LEGACY: several token-id prompts (keyword), optionally with texts.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[
            PoolingParams, Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    # LEGACY: one or several token-id prompts, all passed positionally.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. A single value applies to
                every prompt. Every given value is verified against
                `pooling_task` before any request is submitted.
            prompt_token_ids: Legacy (deprecated) way to pass pre-tokenized
                prompts. Use `prompts` instead.
            truncate_prompt_tokens: Optional prompt truncation length. It is
                passed, together with the model's `max_model_len`, to
                `_validate_truncation_size` to build the tokenization kwargs,
                so it only takes effect when `tokenization_kwargs` is None.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use.
            tokenization_kwargs: Tokenization kwargs forwarded with every
                request. When given, they are used as-is and
                `truncate_prompt_tokens` is not applied.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized with the pooling
                runner.

        Note:
            Passing `prompt_token_ids` is considered legacy and may be
            removed in the future. You should instead pass your inputs
            through the `prompts` parameter.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            # Add a hint when the model could run as a pooling model but was
            # started with a different runner.
            if "pooling" in model_config.supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: combine text prompts and token ids into
            # engine-level prompts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        if isinstance(pooling_params, PoolingParams):
            pooling_params.verify(pooling_task, model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(pooling_task, model_config)

        if tokenization_kwargs is None:
            tokenization_kwargs = dict[str, Any]()
            _validate_truncation_size(model_config.max_model_len,
                                      truncate_prompt_tokens,
                                      tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)

    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per prompt.

        Prompts are batched automatically with the memory constraint in mind;
        for best throughput, submit all prompts together in a single list.

        Args:
            prompts: The prompts to the LLM. A sequence of prompts enables
                batch inference. See [PromptType][vllm.inputs.PromptType]
                for the accepted formats of each prompt.
            truncate_prompt_tokens: Optional prompt truncation length,
                forwarded to `encode`.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters for pooling. If None, the
                defaults are used.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            `EmbeddingRequestOutput` objects holding the embedding vectors,
            ordered like the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        supported = self.llm_engine.model_config.supported_tasks
        if "embed" not in supported:
            raise ValueError("Embedding API is not supported by this model. "
                             "Please set `--task embed`.")

        pooled_outputs = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled_outputs))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            truncate_prompt_tokens: Optional prompt truncation length,
                forwarded to `encode` (same meaning as in `embed`).
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        model_config = self.llm_engine.model_config
        if "classify" not in model_config.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Please set `--task classify`.")

        # The new options default to None, which is exactly what `encode`
        # received before they existed, so existing callers are unaffected.
        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        # Score pairs by cosine similarity of their embeddings. Both sides
        # are embedded in one batched `encode` call, then split back apart.
        pooled: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_task="embed",
        )

        split_at = len(text_1)
        left: list[PoolingRequestOutput] = pooled[:split_at]
        right: list[PoolingRequestOutput] = pooled[split_at:]

        # A single left-hand item is paired with every right-hand item.
        if len(left) == 1:
            left = left * len(right)

        similarity = _cosine_similarity(tokenizer=tokenizer,
                                        embed_1=left,
                                        embed_2=right)

        validated = self.engine_class.validate_outputs(similarity,
                                                       PoolingRequestOutput)
        return list(map(ScoringRequestOutput.from_base, validated))

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by feeding each pair jointly to the model.

        Each `(data_1[i], data_2[i])` pair becomes one prompt, submitted with
        `PoolingParams(task="score")`.
        - Multi-modal models: the prompt is built by `get_score_prompt`.
        - Text models, when `model_config.use_pad_token` is set: the pair is
          tokenized as `text` / `text_pair`.
        - Text models otherwise: the two strings are concatenated and
          tokenized as one text.

        A single-element `data_1` is repeated to match the length of
        `data_2`.

        Args:
            tokenizer: Tokenizer used to build the prompts.
            data_1: Left side of each pair.
            data_2: Right side of each pair.
            truncate_prompt_tokens: Checked by `_validate_truncation_size`
                against the model's `max_model_len`.
            use_tqdm: Forwarded to request submission and the engine loop.
            lora_request: Forwarded to request submission.

        Returns:
            One `ScoringRequestOutput` per pair.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`.
        """
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is only enabled for `--task embed or score`")

        # 1 -> N: repeat the single left-hand item once per right-hand item.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        model_config = self.llm_engine.model_config
        pooling_params = PoolingParams(task="score")

        # Handed to `_validate_truncation_size`, presumably so it can fill in
        # truncation options (confirm in vllm.entrypoints.utils). The same
        # dict is then passed to every tokenization call below.
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        def encode_text_pair(first, second):
            if model_config.use_pad_token:
                # cross_encoder models defaults to using pad_token.
                encoded = tokenizer(
                    text=first,  # type: ignore[arg-type]
                    text_pair=second,  # type: ignore[arg-type]
                    **tokenization_kwargs)
            else:
                # `llm as reranker` models defaults to not using pad_token.
                encoded = tokenizer(
                    text=first + second,  # type: ignore[operator]
                    **tokenization_kwargs)
            return TokensPrompt(
                prompt_token_ids=encoded["input_ids"],
                token_type_ids=encoded.get("token_type_ids"))

        is_multimodal = model_config.is_multimodal_model
        engine_prompts = []
        for first, second in zip(data_1, data_2):
            if is_multimodal:
                _, built = get_score_prompt(
                    model_config=model_config,
                    data_1=first,
                    data_2=second,
                    tokenizer=tokenizer,
                    tokenization_kwargs=tokenization_kwargs,
                )
            else:
                built = encode_text_pair(first, second)
            engine_prompts.append(built)

        self._validate_and_add_requests(
            prompts=engine_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        finished = self._run_engine(use_tqdm=use_tqdm)
        validated = self.engine_class.validate_outputs(finished,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

1340
1341
    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Cross-encoder models score each pair jointly. All other supported
        pooling models score by cosine similarity of the two embeddings.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When a list, it must have the same length as
                the `data_2` list.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Optional prompt-truncation length (in
                tokens). Forwarded unchanged to the selected scoring path.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the pooling runner,
                supports neither the `embed` nor the `classify` task, or is a
                classifier whose `num_labels` is not 1. Also raised for a
                text-only model when an input is a `ScoreMultiModalParam`,
                carries `multi_modal_data`, or cannot be turned into a
                string prompt.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type

        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]
            if "pooling" in model_config.supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")
            raise ValueError(" ".join(messages))

        if all(t not in model_config.supported_tasks
               for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Please set `--task embed` or `--task classify`.")

        if (model_config.task == "classify"
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(data: Union[SingletonPrompt,
                                            Sequence[SingletonPrompt],
                                            ScoreMultiModalParam]):
                if isinstance(data, dict) and "content" in data:
                    raise ValueError(
                        "ScoreMultiModalParam is not supported for "
                        f"{model_config.architecture}")

            check_data_type(data_1)
            check_data_type(data_2)

            def ensure_str(prompt: SingletonPrompt) -> str:
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check rather than `assert`: asserts are stripped
                # under `python -O`, which would let non-text prompts reach
                # the tokenizer.
                if not isinstance(prompt, str):
                    raise ValueError(
                        "Unsupported prompt for scoring: expected a string, "
                        "a TextPrompt or a TokensPrompt, got "
                        f"{type(prompt).__name__}")
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]
            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]
            data_2 = [ensure_str(t) for t in data_2]
        else:
            # Multi-modal path: unwrap a `ScoreMultiModalParam` into its
            # `content` value, and wrap a bare string into a one-item list.
            if isinstance(data_1, dict) and "content" in data_1:
                data_1 = data_1.get("content")  # type: ignore[assignment]
            elif isinstance(data_1, str):
                data_1 = [data_1]

            if isinstance(data_2, dict) and "content" in data_2:
                data_2 = data_2.get("content")  # type: ignore[assignment]
            elif isinstance(data_2, str):
                data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request)

        return self._embedding_score(
            tokenizer,
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            truncate_prompt_tokens,
            use_tqdm,
            lora_request)
1480

1481
1482
1483
1484
1485
1486
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1487
1488
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Forwarded as-is to the engine's `reset_prefix_cache`.

        Returns:
            The engine's return value, unchanged.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1489

1490
1491
1492
1493
1494
1495
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep; it must not process requests until
        `wake_up` is called.

        The caller is responsible for making sure no requests are being
        processed while the engine sleeps. The prefix cache is reset before
        the engine is put to sleep.

        Args:
            level: The sleep level.
                Level 1 offloads the model weights and discards the kv cache,
                whose content is forgotten. The weights are backed up in CPU
                memory, so make sure there is enough CPU memory to hold them.
                Level 1 suits sleeping and then waking the engine to run the
                same model again.
                Level 2 discards both the model weights and the kv cache;
                the content of both is forgotten. Level 2 suits sleeping and
                then waking the engine to run a different or updated model,
                where the previous weights are not needed. It reduces CPU
                memory pressure.
        """
        # Clear cached prefixes first; the engine is put to sleep afterwards.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1512
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake the engine up after a previous sleep. See the [sleep][] method
        for more details.

        Args:
            tags: Optional list of tags selecting which engine memory to
                reallocate. Values must be in `("weights", "kv_cache")`.
                `None` reallocates all memory. Before the engine is used
                again, wake_up must have been called with all tags (or with
                None).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1525

1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects capturing the current state of all
            aggregated metrics from Prometheus, as returned by the V1
            engine's `get_metrics()`.

        Note:
            This method is only available with the V1 LLM engine.
        """
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        assert isinstance(self.llm_engine, V1LLMEngine), (
            "LLM.get_metrics() is only available with the V1 LLM engine.")
        return self.llm_engine.get_metrics()

1540
1541
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert legacy `prompts` / `prompt_token_ids` into `PromptType`s.

        Each argument is batched with `parse_and_batch_prompt` (presumably
        turning a bare string or a flat token-id list into a one-element
        batch; confirm in `vllm.inputs.parse`). When both arguments are
        given, they must describe the same number of requests, and only the
        text prompts are used to build the result.

        Returns:
            A list with one `TextPrompt` per text prompt or, if only token
            ids were given, one `TokensPrompt` per token-id list.

        Raises:
            ValueError: If neither argument is given, or if both are given
                but batch to a different number of requests.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Batch both inputs *before* comparing lengths. `len()` on the raw
        # arguments would measure a bare string by its character count and a
        # flat token list by its token count, which are not request counts.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        parsed_prompts: list[PromptType] = []
        if prompts is not None:
            # Text prompts take precedence when both inputs are given.
            for text in prompts:
                parsed_prompts.append(TextPrompt(prompt=text))
        elif prompt_token_ids is not None:
            for token_ids in prompt_token_ids:
                parsed_prompts.append(
                    TokensPrompt(prompt_token_ids=token_ids))

        return parsed_prompts
1581
1582
1583

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments, then add every prompt to the engine.

        Args:
            prompts: A single prompt (a `str` or a `dict`) or a sequence of
                prompts.
            params: One params object shared by every prompt, or a sequence
                with one entry per prompt. Each `SamplingParams` is mutated
                in place: guided options are merged in, and its
                `output_kind` is set to `RequestOutputKind.FINAL_ONLY`.
            use_tqdm: Truthy shows an "Adding requests" progress bar; a
                callable is used as the tqdm factory.
            lora_request: One LoRA request for all prompts, or a sequence
                with one entry per prompt.
            tokenization_kwargs: Forwarded to every `_add_request` call.
            guided_options: Deprecated; merged into each `SamplingParams`
                via `_add_guided_params`.
            priority: Optional per-prompt priorities. A falsy value (None or
                an empty list) gives every request priority 0.

        Raises:
            ValueError: If `params` or `lora_request` is a sequence whose
                length differs from the number of prompts, or if `priority`
                has fewer entries than there are prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Check before adding anything. A too-short list used to raise
        # IndexError inside the loop below, after earlier requests had
        # already been added to the engine and left orphaned there.
        if priority and len(priority) < num_requests:
            raise ValueError("The length of priority must not be smaller "
                             "than the number of prompts.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                priority=priority[i] if priority else 0,
            )
1637

1638
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine.

        The request ID is the string form of the next value drawn from
        `self.request_counter`. `_run_engine` relies on this: it sorts
        finished outputs by `int(request_id)`.
        """
        request_id = str(next(self.request_counter))
        engine_kwargs = dict(
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
        self.llm_engine.add_request(request_id, prompt, params,
                                    **engine_kwargs)
1655

1656
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Merge a legacy `GuidedDecodingRequest` into `params`.

        When `guided_options` is None, `params` is returned untouched.
        Otherwise `params.guided_decoding` is set in place from the guided
        options, and `params` is returned.

        Raises:
            ValueError: If `params.guided_decoding` is already set and
                `guided_options` is also given.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )
        return params

1679
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until it has no unfinished requests.

        Args:
            use_tqdm: Truthy shows a "Processed prompts" progress bar. A
                callable is used as the tqdm factory. For `RequestOutput`s,
                the bar's postfix shows estimated input/output token speeds.

        Returns:
            Every finished output, sorted by `int(request_id)`.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs)
                            # Read the elapsed time once, and skip the speed
                            # estimate while it is 0. A disabled tqdm
                            # (e.g. `disable=True` or TQDM_DISABLE) always
                            # reports 0, which used to raise
                            # ZeroDivisionError here.
                            elapsed = pbar.format_dict["elapsed"]
                            if elapsed:
                                in_spd = total_in_toks / elapsed
                                out_spd = total_out_toks / elapsed
                                pbar.postfix = (
                                    f"est. speed input: {in_spd:.2f} toks/s, "
                                    f"output: {out_spd:.2f} toks/s")
                            pbar.update(n)
                        else:
                            pbar.update(1)
                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))