llm.py 74.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Sequence
7
from contextlib import contextmanager
8
9
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
10

11
import cloudpickle
12
import torch.nn as nn
13
from pydantic import ValidationError
14
from tqdm.auto import tqdm
15
from typing_extensions import TypeVar, deprecated
16

17
import vllm.envs as envs
18
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
19
20
                              BeamSearchSequence,
                              create_sort_beams_key_function)
21
22
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
23
24
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
                                   PoolerConfig, RunnerOption)
Joe Runde's avatar
Joe Runde committed
25
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
26
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
27
                                         ChatTemplateContentFormatOption,
28
29
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
30
31
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
32
33
34
35
36
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
                                          get_score_prompt)
37
38
from vllm.entrypoints.utils import (_validate_truncation_size,
                                    log_non_default_args)
39
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
40
from vllm.inputs.parse import parse_and_batch_prompt
41
from vllm.logger import init_logger
42
from vllm.lora.request import LoRARequest
43
44
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
45
from vllm.model_executor.layers.quantization import QuantizationMethods
46
47
48
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
49
from vllm.pooling_params import PoolingParams
50
51
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
52
from vllm.tasks import PoolingTask
53
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
54
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
55
from vllm.usage.usage_lib import UsageContext
56
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
57

58
59
60
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

61
62
# Module-level logger, named after this module via vLLM's init_logger.
logger = init_logger(__name__)

63
64
# Result type of callables executed on workers (see `LLM.collective_rpc` and
# `LLM.apply_model`); falls back to `Any` via the PEP 696 `default=`.
_R = TypeVar("_R", default=Any)

65
66

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        runner: The model runner type, forwarded to `EngineArgs` (e.g.
            `"generate"` for generative models). Defaults to `"auto"`.
        convert: The model conversion option, forwarded to `EngineArgs`.
            Defaults to `"auto"`.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights, e.g.
            "awq", "gptq" or "fp8" (see `QuantizationMethods` for the accepted
            values). If None, we first check the `quantization_config`
            attribute in the model config file. If that is None, we assume the
            model weights are not quantized and use `dtype` to determine the
            data type of the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        override_pooler_config: Initialize non-default pooling config or
            override default pooling config for the pooling model.
            e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
        compilation_config: Either an integer, a dictionary, or a
            `CompilationConfig`. If it is an integer, it is used as the level
            of compilation optimization. If it is a dictionary, it can specify
            the full compilation configuration; keys that are not init fields
            of `CompilationConfig` are dropped.
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = True
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Enable deprecation of the legacy API for the duration of a block.

        Sets `DEPRECATE_LEGACY` to True on entry and restores whatever value
        it held before on exit. The restore runs in a `finally`, so the flag
        is reset even if the body of the `with` block raises.
        """
        # Remember the prior value instead of hard-coding False on exit;
        # otherwise leaving the block would silently change the class default.
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = previous

170
171
172
    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """Build the engine that backs this `LLM` instance.

        Named arguments are normalized where needed and forwarded to
        `EngineArgs`; any extra `**kwargs` are passed through to it as well.
        See the class docstring for per-argument documentation.

        Raises:
            ValueError: If `kv_transfer_config` is given as a dict that cannot
                be validated into a `KVTransferConfig`.
        """
        # Offline use: keep stats logging off unless the caller asked for it.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class given as a type (rather than a qualified name string)
        # is serialized with cloudpickle to avoid pickling issues.
        if isinstance(kwargs.get("worker_cls"), type):
            kwargs["worker_cls"] = cloudpickle.dumps(kwargs["worker_cls"])

        # Allow kv_transfer_config to be passed as a plain dict; validation
        # failures are logged and surfaced as ValueError.
        kv_transfer_dict = kwargs.get("kv_transfer_config")
        if isinstance(kv_transfer_dict, dict):
            from vllm.config import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **kv_transfer_dict)
            except ValidationError as e:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    kv_transfer_dict, e)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {e}") from e

        # Normalize compilation_config (None / int level / dict / instance)
        # into a CompilationConfig instance.
        if compilation_config is None:
            resolved_compilation_config = CompilationConfig()
        elif isinstance(compilation_config, int):
            resolved_compilation_config = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            # Keep only entries whose key is an init field of the config.
            def _is_init_item(item):
                return is_init_field(CompilationConfig, item[0])

            resolved_compilation_config = CompilationConfig(
                **dict(filter(_is_init_item, compilation_config.items())))
        else:
            resolved_compilation_config = compilation_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides={} if hf_overrides is None else hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=resolved_compilation_config,
            **kwargs,
        )

        log_non_default_args(engine_args)

        # The engine factory picks between the V0 and V1 engines.
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily populated by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

        # V1 engines report supported tasks directly; V0 reads them from the
        # model config.
        engine = self.llm_engine
        if envs.VLLM_USE_V1:
            supported_tasks = engine.get_supported_tasks()  # type: ignore
        else:
            supported_tasks = engine.model_config.supported_tasks

        logger.info("Supported_tasks: %s", supported_tasks)

        self.supported_tasks = supported_tasks

297
298
299
300
301
302
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the engine tokenizer associated with `lora_request`.

        The lookup is delegated to the engine's tokenizer group via
        `get_lora_tokenizer`; pass `None` for the tokenizer without a LoRA.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
303
304

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install `tokenizer` on the engine's tokenizer group.

        Tokenizers that do not already look like cached wrappers are wrapped
        with `get_cached_tokenizer` before being installed.
        """
        group = self.llm_engine.get_tokenizer_group()

        # CachedTokenizer classes are generated dynamically, so the class name
        # is the only available marker. A user-defined tokenizer whose class
        # name happens to start with "Cached" will be misdetected as cached.
        already_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if already_cached else
                           get_cached_tokenizer(tokenizer))
314

315
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the model's default `SamplingParams`.

        The model config's non-default sampling settings are fetched once and
        cached on the instance. When that mapping is non-empty it is applied
        via `SamplingParams.from_optional`; otherwise a plain
        `SamplingParams()` is returned.
        """
        if self.default_sampling_params is None:
            model_config = self.llm_engine.model_config
            self.default_sampling_params = (
                model_config.get_diff_sampling_param())

        overrides = self.default_sampling_params
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

323
324
325
326
327
328
329
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        # Accepted by the implementation; exposed here so type checkers
        # allow `generate(..., priority=[...])` on the non-legacy API.
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
414
415
416
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: (Legacy) Token IDs of the prompt(s). Prefer
                passing token IDs through `prompts` instead, e.g. as a
                [TokensPrompt][vllm.inputs.TokensPrompt].
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. Either a
                single request applied to every prompt or one per prompt.
            guided_options_request: Guided decoding options, if any. A dict is
                converted to a `GuidedDecodingRequest` and may specify only
                one guided decoding option.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a generative model, or if a dict
                `guided_options_request` specifies more than one option.

        Note:
            Passing `prompts` as a keyword argument, or passing
            `prompt_token_ids` at all, is considered legacy and may be
            deprecated in the future. Pass `prompts` positionally instead,
            and supply token IDs through it (e.g. as a `TokensPrompt`).
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model.")

        # Legacy path: merge separate prompt strings and token IDs into the
        # unified prompt representation.
        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Prompt truncation is only taken from a single, shared
        # SamplingParams; per-prompt lists leave it unset here.
        tokenization_kwargs: dict[str, Any] = {}
        truncate_prompt_tokens = None
        if isinstance(sampling_params, SamplingParams):
            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens

        # NOTE(review): presumably validates the truncation size against
        # max_model_len and fills tokenization_kwargs — confirm in helper.
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(
            parsed_prompts, lora_request)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            guided_options=guided_options_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
517

518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
    def _get_modality_specific_lora_reqs(
            self, parsed_prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        """Attach default per-modality LoRA requests to multimodal prompts.

        Args:
            parsed_prompts: A single prompt or a sequence of prompts.
            lora_request: The caller-supplied LoRA request(s): `None`, one
                request applied to every prompt, or one request per prompt.

        Returns:
            `lora_request` unchanged when default multimodal LoRAs cannot
            apply (no LoRA config, no `default_mm_loras`, or a model that is
            not multimodal); otherwise a list with one optional `LoRARequest`
            per prompt.

        Raises:
            ValueError: If `lora_request` is a sequence whose length differs
                from the number of prompts.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or lora_config.default_mm_loras is None):
            return lora_request

        # A single string prompt is itself a Sequence; wrap it explicitly so
        # it is not iterated character by character (which would produce one
        # LoRA entry per character).
        if isinstance(parsed_prompts, str) or not isinstance(
                parsed_prompts, Sequence):
            parsed_prompts = [parsed_prompts]

        if isinstance(lora_request, Sequence):
            # zip() below would silently drop extra entries; reject the
            # mismatch up front instead.
            if len(lora_request) != len(parsed_prompts):
                raise ValueError("Lora request list should be the same "
                                 "length as the prompts")
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(parsed_prompts)

        return [
            self._resolve_single_prompt_mm_lora(
                parsed_prompt,
                opt_lora_req,
                lora_config.default_mm_loras,
            ) for parsed_prompt, opt_lora_req in zip(parsed_prompts,
                                                     optional_loras)
        ]

    def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        """Choose the LoRA request for one prompt, honoring modality defaults.

        Args:
            parsed_prompt: A single prompt. Only dict prompts that contain a
                `multi_modal_data` entry are considered for default LoRAs.
            lora_request: The LoRA request explicitly supplied for this
                prompt, if any.
            default_mm_loras: Mapping from modality name to the path of the
                LoRA registered as that modality's default.

        Returns:
            The explicit `lora_request` when one was given (a warning is
            logged if its ID differs from the matching modality LoRA's ID);
            otherwise a new `LoRARequest` for the single matching modality.
            `lora_request` is returned unchanged when zero or several
            modalities match.
        """
        if (not default_mm_loras or not isinstance(parsed_prompt, dict)
                or "multi_modal_data" not in parsed_prompt):
            return lora_request

        parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt)

        # `multi_modal_data` may be present but explicitly None; treat that
        # as "no multimodal data" rather than failing on `.keys()`.
        mm_data = parsed_prompt["multi_modal_data"] or {}
        intersection = set(mm_data.keys()).intersection(
            default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", intersection)
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request always wins; warn only when its ID
        # differs from the modality default's ID.
        if lora_request is not None:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

595
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        """
        Invoke a method on every worker and gather the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and sent to every worker for execution.

                A callable receives the worker object as an additional `self`
                argument, alongside the values passed through `args` and
                `kwargs`.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            One result per worker, collected into a list.

        Note:
            Prefer this API for control messages only; set up separate
            data-plane communication for passing data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
624
625

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Invoke `func` directly on the model held by each worker.

        Args:
            func: Callable that receives a worker's model module.

        Returns:
            The value returned by `func` on each worker, as a list.
        """
        return self.llm_engine.model_executor.apply_model(func)

633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional LoRA request corresponding to each prompt.

        Args:
            lora_request: `None` (no LoRA), a single `LoRARequest` shared by
                every prompt, or a sequence holding one entry per prompt.
            prompts: The prompts being processed; only their count is used.

        Returns:
            A list with exactly one (possibly `None`) LoRA request per
            prompt, aligned with `prompts`.

        Raises:
            ValueError: If a sequence is given whose length differs from the
                number of prompts.
            TypeError: If `lora_request` has any other type.
        """
        if lora_request is None:
            return [None] * len(prompts)

        # str/bytes are Sequences too, but never a valid per-prompt list.
        if isinstance(lora_request, Sequence) and not isinstance(
                lora_request, (str, bytes)):
            if len(lora_request) != len(prompts):
                raise ValueError("Lora request list should be the same "
                                 "length as the prompts")
            # A correctly sized list used to fall through to the TypeError
            # below; hand back a copy aligned one-to-one with the prompts.
            return list(lora_request)

        if isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")

649
650
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        use_tqdm: bool = False,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompt dicts. Each is either a `TokensPrompt`
                (with `"prompt_token_ids"`) or a `TextPrompt` (with a
                `"prompt"` string, tokenized here). Optional
                `"multi_modal_data"` / `"mm_processor_kwargs"` entries are
                carried along on every beam of that prompt.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any: either
                one request shared by all prompts or one per prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.

        Returns:
            One `BeamSearchOutput` per prompt, in input order. Each holds at
            most `params.beam_width` sequences, sorted in descending order
            by the length-penalized sort key, with `text` decoded.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

        tokenizer = self.get_tokenizer()
        sort_beams_key = create_sort_beams_key_function(
            tokenizer.eos_token_id,
            length_penalty,
        )

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            # Rebuild a full prompt from the beam's tokens plus any
            # multimodal payload it carries.
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, prompts):
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )

        token_iter = range(max_tokens)
        if use_tqdm:
            token_iter = tqdm(token_iter,
                              desc="Beam search",
                              unit="token",
                              unit_scale=False)
            logger.warning(
                "The progress bar shows the upper bound on token steps and "
                "may finish early due to stopping conditions. It does not "
                "reflect instance-level progress.")

        for _ in token_iter:
            # Flatten every instance's live beams into one batch, in order.
            # chain.from_iterable is linear in the number of beams, whereas
            # sum(..., []) re-copied the growing list once per instance.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            # all_beams[pos[k]:pos[k + 1]] are the beams of instances[k].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            if not all_beams:
                break

            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                lora_request=current_beam.lora_request,
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            # EOS finishes a beam unless ignore_eos is set.
                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive when the step budget runs out also compete.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
809
810
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for one or more chat conversations.

        Each conversation is rendered with the chat template, tokenized,
        and handed to the [generate][] method. Multi-modal inputs are
        passed the same way as with the OpenAI API.

        Args:
            messages: A single conversation or a list of conversations.

                - A conversation is a list of messages.
                - A message is a dictionary with 'role' and 'content' keys.

            sampling_params: Sampling parameters for text generation. `None`
                selects the defaults; a single value applies to every
                prompt; a list must match the number of prompts and is
                paired with them one by one.
            use_tqdm: `True` shows a tqdm progress bar. A callable (e.g.
                `functools.partial(tqdm, leave=False)`) is used to build the
                progress bar. `False` disables it.
            lora_request: LoRA request to use for generation, if any. It also
                selects the tokenizer used for rendering.
            chat_template: Template used to structure the chat. Falls back to
                the model's default chat template when not given.
            chat_template_content_format: How message content is rendered.

                - "string" renders the content as a string.
                  Example: `"Who are you?"`
                - "openai" renders the content as a list of dictionaries,
                  similar to the OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, asks the chat template to append
                a generation prompt (forwarded as `add_generation_prompt`).
            continue_final_message: If True, continues the final message of
                the conversation instead of starting a new one. Must not be
                `True` together with `add_generation_prompt`.
            tools: Tool definitions forwarded to the chat template; they are
                also used when resolving the content format.
            chat_template_kwargs: Extra kwargs for the chat template. Entries
                here override the explicit arguments above.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            `RequestOutput` objects with the generated responses, in the
            same order as the input conversations.
        """
        # Normalize to a batch: a single conversation is wrapped in a list.
        conversations: list[list[ChatCompletionMessageParam]] = (
            cast(list[list[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list) else
            [cast(list[ChatCompletionMessageParam], messages)])

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        # Explicit arguments first; user-supplied chat_template_kwargs win.
        template_kwargs: dict[str, Any] = {
            "chat_template": chat_template,
            "add_generation_prompt": add_generation_prompt,
            "continue_final_message": continue_final_message,
            "tools": tools,
            **(chat_template_kwargs or {}),
        }

        prompts: list[Union[TokensPrompt, TextPrompt]] = []
        for chat_messages in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            parsed_conversation, mm_data = parse_chat_messages(
                chat_messages,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=chat_messages,
                    **template_kwargs,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=parsed_conversation,
                    model_config=model_config,
                    **template_kwargs,
                )
                # The chat template already emits the special tokens, so
                # the tokenizer must not add another set.
                token_ids = tokenizer.encode(rendered,
                                             add_special_tokens=False)

            engine_prompt = TokensPrompt(prompt_token_ids=token_ids)
            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data
            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs
            prompts.append(engine_prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

954
955
956
957
958
959
960
    # Preferred call form: prompts are positional-only.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Text or token prompts passed positionally."""

    # LEGACY: a single text prompt with optional token IDs.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Deprecated: one text prompt plus optional token IDs."""

    # LEGACY: several text prompts with optional token IDs.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Deprecated: several text prompts plus optional token IDs."""

    # LEGACY: one list of token IDs (keyword) with an optional text prompt.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Deprecated: one token-ID list plus an optional text prompt."""

    # LEGACY: several token-ID lists (keyword) with optional text prompts.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Deprecated: several token-ID lists plus optional text prompts."""

    # LEGACY: token IDs only, passed positionally after two explicit Nones.
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Deprecated: single or batched token IDs, positional-only form."""

nunjunj's avatar
nunjunj committed
1051
1052
1053
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method batches the given prompts automatically, taking the
        memory constraint into account. For the best performance, put all
        of your prompts into a single list and pass it in one call.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for the format of each prompt.
            pooling_params: The pooling parameters. If None, default
                `PoolingParams()` are used. Each one is verified against
                `pooling_task` before any request is submitted.
            prompt_token_ids: Deprecated. Token IDs to combine with
                `prompts`; pass them inside `prompts` instead.
            truncate_prompt_tokens: Truncation setting used to build the
                tokenization kwargs. Only consulted when
                `tokenization_kwargs` is None.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use, if any.
            pooling_task: Override the pooling task to use.
            tokenization_kwargs: Tokenizer kwargs passed through to request
                processing. When given, `truncate_prompt_tokens` is not
                applied.

        Returns:
            `PoolingRequestOutput` objects with the pooled hidden states,
            in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running as a pooling model.

        Note:
            Supplying token IDs through the `prompt_token_ids` keyword is
            legacy and may be removed in the future. Pass them through the
            `prompts` parameter instead.
        """
        model_config = self.llm_engine.model_config
        if model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        if prompt_token_ids is None:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)
        else:
            # Legacy path: fold the deprecated token IDs into prompts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        # Check every pooling config against the requested task up front.
        params_to_verify = ([pooling_params] if isinstance(
            pooling_params, PoolingParams) else pooling_params)
        for pooling_param in params_to_verify:
            pooling_param.verify(pooling_task, model_config)

        if tokenization_kwargs is None:
            tokenization_kwargs = {}
            _validate_truncation_size(model_config.max_model_len,
                                      truncate_prompt_tokens,
                                      tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        engine_outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(engine_outputs,
                                                  PoolingRequestOutput)

1143
1144
1145
1146
1147
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method batches the given prompts automatically, taking the
        memory constraint into account. For the best performance, put all
        of your prompts into a single list and pass it in one call.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for the format of each prompt.
            truncate_prompt_tokens: Forwarded to [encode][] to control
                prompt truncation.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters. If None, the default
                pooling parameters are used.
            lora_request: LoRA request to use, if any.

        Returns:
            `EmbeddingRequestOutput` objects with the embedding vectors, in
            the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "embed" task.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`.")

        pooled = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method batches the given prompts automatically, taking the
        memory constraint into account. For the best performance, put all
        of your prompts into a single list and pass it in one call.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for the format of each prompt.
            truncate_prompt_tokens: Forwarded to [encode][] to control
                prompt truncation. Defaults to None (same as before this
                parameter existed).
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            pooling_params: The pooling parameters. If None, the default
                pooling parameters are used.
            lora_request: LoRA request to use, if any.

        Returns:
            `ClassificationRequestOutput` objects with the classification
            results, in the same order as the input prompts.

        Raises:
            ValueError: If the model does not support the "classify" task.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`.")

        # Mirrors embed(): the new keyword-only arguments default to None,
        # which is exactly what encode() used when they were not forwarded.
        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]

1236
1237
1238
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score `text_1` against `text_2` by comparing their embeddings.

        Both sides are embedded in one `encode(..., pooling_task="embed")`
        batch and compared pairwise with `_cosine_similarity`. When
        `text_1` holds exactly one entry, it is paired with every entry of
        `text_2`.

        Returns:
            `ScoringRequestOutput` objects, one per compared pair.
        """
        # Embed both sides together, then split at the boundary.
        split_at = len(text_1)
        embeddings: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_task="embed",
        )
        left_embeds = embeddings[:split_at]
        right_embeds = embeddings[split_at:]

        # Broadcast a single query across all documents.
        if len(left_embeds) == 1:
            left_embeds = left_embeds * len(right_embeds)

        similarities = _cosine_similarity(tokenizer=tokenizer,
                                          embed_1=left_embeds,
                                          embed_2=right_embeds)

        validated = self.engine_class.validate_outputs(similarities,
                                                       PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in validated]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score pairs by running each joined pair through the model.

        Each ``(data_1[i], data_2[i])`` pair becomes one prompt:

        * multi-modal models build it with ``get_score_prompt``;
        * text models tokenize it as ``text``/``text_pair`` when
          ``model_config.use_pad_token`` is set, and as the concatenation
          ``q + t`` otherwise.

        All prompts are submitted with ``PoolingParams(task="score")`` and run
        to completion.

        Args:
            tokenizer: Tokenizer used to build the prompts.
            data_1: Left-hand inputs. A single item is repeated for every
                ``data_2`` item.
            data_2: Right-hand inputs.
            truncate_prompt_tokens: Passed to ``_validate_truncation_size``,
                which fills the tokenization kwargs.
            use_tqdm: Progress-bar setting for submission and execution.
            lora_request: LoRA request(s) for the submitted prompts.

        Returns:
            One ``ScoringRequestOutput`` per pair.

        Raises:
            ValueError: If ``tokenizer`` is a ``MistralTokenizer``.
        """
        model_config = self.llm_engine.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        # 1 -> N: repeat the lone left item so every right item gets a partner.
        if len(data_1) == 1:
            data_1 = data_1 * len(data_2)

        pooling_params = PoolingParams(task="score")
        tokenization_kwargs: dict[str, Any] = {}
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        pairs = list(zip(data_1, data_2))

        if model_config.is_multimodal_model:
            # get_score_prompt returns a 2-tuple; only the engine prompt is
            # needed here.
            prompts = [
                get_score_prompt(
                    model_config=model_config,
                    data_1=left,
                    data_2=right,
                    tokenizer=tokenizer,
                    tokenization_kwargs=tokenization_kwargs,
                )[1] for left, right in pairs
            ]
        else:
            prompts = []
            for left, right in pairs:
                if model_config.use_pad_token:
                    # cross_encoder models defaults to using pad_token.
                    encoded = tokenizer(
                        text=left,  # type: ignore[arg-type]
                        text_pair=right,  # type: ignore[arg-type]
                        **tokenization_kwargs)
                else:
                    # `llm as reranker` models defaults to not using pad_token.
                    encoded = tokenizer(
                        text=left + right,  # type: ignore[operator]
                        **tokenization_kwargs)
                prompts.append(
                    TokensPrompt(
                        prompt_token_ids=encoded["input_ids"],
                        token_type_ids=encoded.get("token_type_ids")))

        self._validate_and_add_requests(
            prompts=prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        finished = self._run_engine(use_tqdm=use_tqdm)
        checked = self.engine_class.validate_outputs(finished,
                                                     PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in checked]

1340
1341
    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Compute similarity scores for `<text, text_pair>` or
        `<multi-modal data, multi-modal data pair>` inputs.

        Accepted shapes are `1 -> 1`, `1 -> N` and `N -> N`. For `1 -> N`, the
        single `data_1` item is repeated `N` times so that it pairs with every
        `data_2` item. The resulting pairs become the prompts sent to the
        model. Prompts are batched automatically under the memory
        constraint, so passing all inputs in a single call gives the best
        performance.

        Multi-modal data (images, etc.) is accepted only by multi-modal
        models, and its prompt structure must match the format the model
        expects.

        Args:
            data_1: A single prompt, a list of prompts, or a
                `ScoreMultiModalParam`, holding either text or multi-modal
                data. A list must have the same length as the `data_2` list
                (unless it has exactly one item).
            data_2: The data paired with `data_1` to form the model inputs.
                Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for details on the format of each
                prompt.
            truncate_prompt_tokens: Optional token limit forwarded to the
                scoring path's prompt-truncation handling. `None` (default)
                presumably means no truncation is requested.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects holding the generated
            scores, in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling model, supports neither
                `embed` nor `classify`, is a cross-encoder whose `num_labels`
                is not 1, or is a text-only model that was given multi-modal
                input.
        """
        model_config = self.llm_engine.model_config

        if model_config.runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        supported_tasks = self.supported_tasks
        if not any(task in supported_tasks for task in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Try converting the model using "
                             "`--convert embed` or `--convert classify`.")

        if model_config.is_cross_encoder:
            num_labels = getattr(model_config.hf_config, "num_labels", 0)
            if num_labels != 1:
                raise ValueError(
                    "Score API is only enabled for num_labels == 1.")

        # Tokenizers such as the one for "cross-encoder/ms-marco-MiniLM-L-6-v2"
        # cannot take lists of tokens via the `text`/`text_pair` kwargs, so
        # text-only inputs given as token IDs are decoded back to strings.
        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def reject_multimodal_param(data) -> None:
                # A dict carrying "content" is a ScoreMultiModalParam.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError(
                        "ScoreMultiModalParam is not supported "
                        f"for {model_config.architecture}")

            def to_text(prompt):
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    # `elif` matters: after decoding, `prompt` is a str and a
                    # substring test for "prompt" must not run on it.
                    if "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                assert type(prompt) is str
                return prompt

            def to_text_list(data) -> list[str]:
                # A single prompt (str or prompt dict) becomes a 1-item list.
                batch = [data] if isinstance(data, (str, dict)) else data
                return [to_text(item) for item in batch]

            # Reject ScoreMultiModalParam on both sides before converting.
            reject_multimodal_param(data_1)
            reject_multimodal_param(data_2)
            data_1 = to_text_list(data_1)
            data_2 = to_text_list(data_2)

        def unwrap(data):
            # ScoreMultiModalParam -> its content list; bare str -> [str].
            if isinstance(data, dict) and "content" in data:
                return data.get("content")
            if isinstance(data, str):
                return [data]
            return data

        data_1 = unwrap(data_1)
        data_2 = unwrap(data_2)

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        scorer = (self._cross_encoding_score
                  if model_config.is_cross_encoder else self._embedding_score)
        return scorer(
            tokenizer,
            data_1,  # type: ignore[arg-type]
            data_2,  # type: ignore[arg-type]
            truncate_prompt_tokens,
            use_tqdm,
            lora_request,
        )
1474

1475
1476
1477
1478
1479
1480
    def start_profile(self) -> None:
        """Ask the underlying engine to start profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1481
1482
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, forwarded as-is to the engine.

        Returns:
            Whatever boolean the engine's ``reset_prefix_cache`` reports.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1483

1484
1485
1486
1487
1488
1489
    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. While asleep the engine must not process any
        requests: the caller is responsible for making sure nothing is in
        flight from now until `wake_up` is called.

        The prefix cache is reset before the engine is put to sleep.

        Args:
            level: The sleep level.

                * Level 1 offloads the model weights and discards the kv
                  cache, whose content is forgotten. The weights are backed
                  up in CPU memory, so make sure enough CPU memory is
                  available to hold them. Use it to sleep and then wake up
                  the same model.
                * Level 2 discards both the model weights and the kv cache,
                  forgetting the content of both. Use it to sleep and then
                  wake up with a different or updated model, when the
                  previous weights are no longer needed. It reduces CPU
                  memory pressure.
        """
        # Drop cached prefixes first, then hand off to the engine.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1506
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Wake the engine up after sleep mode; see the [sleep][] method for the
        sleep levels.

        Args:
            tags: Optional list of tags selecting which engine memory to
                reallocate. Allowed values are `("weights", "kv_cache")`.
                `None` reallocates everything. Before the engine is used
                again, wake_up must have been called with all tags (or with
                None).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1519

1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of ``Metric`` objects capturing the current state of all
            aggregated metrics from Prometheus.

        Note:
            This method is only available with the V1 LLM engine. With any
            other engine type the ``isinstance`` assertion below fails.
        """
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        assert isinstance(self.llm_engine, V1LLMEngine)
        return self.llm_engine.get_metrics()

1534
1535
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert legacy ``prompts``/``prompt_token_ids`` arguments to prompts.

        Each argument may describe a single request (a string, or one list of
        token IDs) or a batch of requests. When both are supplied they must
        describe the same number of requests, and the text prompts take
        precedence.

        Args:
            prompts: A text prompt or a list of text prompts.
            prompt_token_ids: One list of token IDs or a list of such lists.

        Returns:
            A list with one ``TextPrompt`` or ``TokensPrompt`` per request.

        Raises:
            ValueError: If neither argument is given, or if both are given
                with a different number of requests.
        """
        # skip_tokenizer_init is now checked in engine
        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Batch both arguments into per-request lists *before* comparing
        # lengths. Comparing the raw arguments is wrong for single prompts:
        # len("Hello") counts characters while len([1, 2, 3]) counts tokens,
        # so a matching single text/token pair used to be rejected.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        parsed_prompts: list[PromptType] = []
        if prompts is not None:
            parsed_prompts.extend(TextPrompt(prompt=text) for text in prompts)
        else:
            # Guaranteed by the "neither provided" check above.
            assert prompt_token_ids is not None
            parsed_prompts.extend(
                TokensPrompt(prompt_token_ids=token_ids)
                for token_ids in prompt_token_ids)

        return parsed_prompts
1575
1576
1577

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments, then submit every prompt.

        ``params`` and ``lora_request`` may be a single value shared by all
        prompts or a sequence with one entry per prompt. ``priority`` may be
        ``None``/empty (every request gets priority 0) or one entry per
        prompt. All length checks run before any request is submitted, so a
        mismatch never leaves a partial batch queued in the engine.

        Args:
            prompts: A single prompt or a sequence of prompts.
            params: Sampling/pooling params, shared or one per prompt.
            use_tqdm: Progress-bar setting for the submission loop.
            lora_request: LoRA request, shared or one per prompt.
            tokenization_kwargs: Forwarded to every ``_add_request`` call.
            guided_options: Deprecated. Folded into each ``SamplingParams``
                via ``_add_guided_params``.
            priority: Optional per-prompt priorities.

        Raises:
            ValueError: If a per-request sequence does not match the number
                of prompts, or if ``_add_guided_params`` rejects the params.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Previously unchecked: a short list raised IndexError inside the add
        # loop below, after earlier requests had already been submitted.
        # An empty list still means "default priority", as in the loop.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                priority=priority[i] if priority else 0,
            )
1631

1632
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit one prompt to the engine under the next request ID.

        Request IDs are the decimal string form of the next value drawn from
        ``self.request_counter``.
        """
        req_id = str(next(self.request_counter))
        engine = self.llm_engine
        engine.add_request(
            req_id,
            prompt,
            params,
            priority=priority,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )
1649

1650
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy a legacy ``GuidedDecodingRequest`` into ``params`` in place.

        Args:
            params: Sampling params to update. They are mutated and also
                returned.
            guided_options: Legacy guided-decoding options. When ``None``,
                ``params`` is returned untouched.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``params.guided_decoding`` is already set and
                ``guided_options`` is also given.
        """
        if guided_options is not None:
            # Refuse to silently overwrite guided decoding set on the params.
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")
            opts = guided_options
            params.guided_decoding = GuidedDecodingParams(
                json=opts.guided_json,
                regex=opts.guided_regex,
                choice=opts.guided_choice,
                grammar=opts.guided_grammar,
                json_object=opts.guided_json_object,
                backend=opts.guided_decoding_backend,
                whitespace_pattern=opts.guided_whitespace_pattern,
                structural_tag=opts.structural_tag,
            )
        return params

1673
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If truthy, show a progress bar over the requests that
                are unfinished when this call starts. A callable is used as
                the tqdm factory; otherwise ``tqdm`` itself is used.

        Returns:
            Every finished output, sorted by integer request ID. Requests can
            finish out of order, so sorting restores submission order.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            pbar = tqdm_func(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: list[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            n = len(output.outputs)
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids) * n
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs)
                            # Read the elapsed time once. It can be 0.0 with
                            # a coarse clock; skip the estimate rather than
                            # dividing by zero.
                            elapsed = pbar.format_dict["elapsed"]
                            if elapsed > 0:
                                in_spd = total_in_toks / elapsed
                                out_spd = total_out_toks / elapsed
                                pbar.postfix = (
                                    f"est. speed input: {in_spd:.2f} toks/s, "
                                    f"output: {out_spd:.2f} toks/s")
                            pbar.update(n)
                        else:
                            pbar.update(1)
                        if pbar.n == num_requests:
                            pbar.refresh()

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))