llm.py 74.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
import itertools
5
import warnings
6
from collections.abc import Sequence
7
from contextlib import contextmanager
8
9
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Optional, Union,
                    cast, overload)
10

11
import cloudpickle
12
import torch.nn as nn
13
from pydantic import ValidationError
14
from tqdm.auto import tqdm
15
from typing_extensions import TypeVar, deprecated
16

17
import vllm.envs as envs
18
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
19
20
                              BeamSearchSequence,
                              create_sort_beams_key_function)
21
22
from vllm.config import (CompilationConfig, ModelDType, TokenizerMode,
                         is_init_field)
23
24
from vllm.engine.arg_utils import (ConvertOption, EngineArgs, HfOverrides,
                                   PoolerConfig, RunnerOption)
Joe Runde's avatar
Joe Runde committed
25
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
26
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
27
                                         ChatTemplateContentFormatOption,
28
29
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
30
31
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
32
33
34
35
36
from vllm.entrypoints.score_utils import (ScoreContentPartParam,
                                          ScoreMultiModalParam,
                                          _cosine_similarity,
                                          _validate_score_input_lens,
                                          get_score_prompt)
37
from vllm.entrypoints.utils import _validate_truncation_size
38
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
39
from vllm.inputs.parse import parse_and_batch_prompt
40
from vllm.logger import init_logger
41
from vllm.lora.request import LoRARequest
42
43
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
44
from vllm.model_executor.layers.quantization import QuantizationMethods
45
46
47
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
48
from vllm.pooling_params import PoolingParams
49
50
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
51
from vllm.tasks import PoolingTask
52
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
53
                                               get_cached_tokenizer)
yhu422's avatar
yhu422 committed
54
from vllm.usage.usage_lib import UsageContext
55
from vllm.utils import Counter, Device, deprecate_kwargs, is_list_of
56

57
58
59
if TYPE_CHECKING:
    from vllm.v1.metrics.reader import Metric

60
61
logger = init_logger(__name__)

# Result type of callables dispatched to workers through ``collective_rpc``
# and ``apply_model``; falls back to ``Any`` when it cannot be inferred.
_R = TypeVar("_R", default=Any)

64
65

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
66
67
68
69
70
71
72
73
74
75
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
76
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
77
78
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
79
80
81
        skip_tokenizer_init: If True, skip initialization of tokenizer and
            detokenizer. Expect valid `prompt_token_ids` and `None` for
            `prompt` from the input.
82
83
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
84
85
86
87
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
88
89
90
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
91
92
93
94
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
95
        quantization: The method used to quantize the model weights. Currently,
96
            we support "awq", "gptq", and "fp8" (experimental).
97
98
99
100
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
101
102
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
103
104
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
105
106
107
108
109
110
111
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
112
113
114
115
116
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0;
            otherwise, too small values may cause out-of-memory (OOM) errors.
            Note that `best_of` is only supported in V0.
117
118
119
120
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
121
122
123
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
124
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
125
            When a sequence has context length larger than this, we fall back
126
127
128
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
129
130
        disable_custom_all_reduce: See
            [ParallelConfig][vllm.config.ParallelConfig].
131
132
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
133
        hf_token: The token to use as HTTP bearer authorization for remote
            files. If `True`, will use the token generated when running
            `huggingface-cli login` (stored in `~/.huggingface`).
136
137
138
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
139
140
141
142
143
144
145
146
        mm_processor_kwargs: Arguments to be forwarded to the model's processor
            for multi-modal data, e.g., image processor. Overrides for the
            multi-modal processor obtained from `AutoProcessor.from_pretrained`.
            The available overrides depend on the model that is being run.
            For example, for Phi-3-Vision: `{"num_crops": 4}`.
        override_pooler_config: Initialize non-default pooling config or
            override default pooling config for the pooling model.
            e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
147
148
149
        compilation_config: An integer, a dictionary, or a
            `CompilationConfig` instance. An integer is used as the level of
            compilation optimization. A dictionary can specify the full
            compilation configuration; keys that are not `CompilationConfig`
            init fields are ignored.
150
        **kwargs: Arguments for [`EngineArgs`][vllm.EngineArgs].
nunjunj's avatar
nunjunj committed
151

152
153
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the [AsyncLLMEngine][vllm.AsyncLLMEngine] class instead.
155
    """
156

157
    DEPRECATE_LEGACY: ClassVar[bool] = True
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Enable deprecation of the legacy API for the duration of a block.

        Sets ``DEPRECATE_LEGACY`` to ``True`` while the ``with`` body runs,
        then restores whatever value the flag had before entering. The
        restore happens in ``finally``, so it also runs when the body raises.
        Because the flag defaults to ``True``, unconditionally resetting it
        to ``False`` on exit would silently turn deprecation off for the
        rest of the process.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = previous

169
170
171
    def __init__(
        self,
        model: str,
        *,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer: Optional[str] = None,
        tokenizer_mode: TokenizerMode = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: ModelDType = "auto",
        quantization: Optional[QuantizationMethods] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: Optional[int] = None,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_token: Optional[Union[bool, str]] = None,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """Normalize the constructor arguments and create the engine.

        The named arguments are documented on the class; anything else in
        ``**kwargs`` is forwarded unchanged to ``EngineArgs``.

        Raises:
            ValueError: If ``kv_transfer_config`` is a dict that does not
                validate as a ``KVTransferConfig``.
        """
        # Stats logging is off by default unless the caller asks for it.
        kwargs.setdefault("disable_log_stats", True)

        # A worker class passed as a type (not a qualified string name) is
        # serialized with cloudpickle to avoid pickling issues.
        custom_worker = kwargs.get("worker_cls")
        if isinstance(custom_worker, type):
            kwargs["worker_cls"] = cloudpickle.dumps(custom_worker)

        # Promote a plain-dict KV transfer config to a KVTransferConfig so
        # validation problems are reported here as a ValueError.
        raw_kv_transfer = kwargs.get("kv_transfer_config")
        if isinstance(raw_kv_transfer, dict):
            from vllm.config import KVTransferConfig
            try:
                kwargs["kv_transfer_config"] = KVTransferConfig(
                    **raw_kv_transfer)
            except ValidationError as err:
                logger.error(
                    "Failed to convert 'kv_transfer_config' dict to "
                    "KVTransferConfig object. Dict: %s. Error: %s",
                    raw_kv_transfer, err)
                raise ValueError(
                    f"Invalid 'kv_transfer_config' provided: {err}") from err

        # Accept an optimization level, a (filtered) dict of init fields, or
        # a ready-made CompilationConfig.
        if compilation_config is None:
            compilation_config_instance = CompilationConfig()
        elif isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig(
                level=compilation_config)
        elif isinstance(compilation_config, dict):
            compilation_config_instance = CompilationConfig(
                **{
                    key: value
                    for key, value in compilation_config.items()
                    if is_init_field(CompilationConfig, key)
                })
        else:
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_token=hf_token,
            hf_overrides={} if hf_overrides is None else hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # The engine factory chooses between the V0 and V1 implementations.
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args=engine_args, usage_context=UsageContext.LLM_CLASS)
        self.engine_class = type(self.llm_engine)

        self.request_counter = Counter()
        # Lazily filled by get_default_sampling_params().
        self.default_sampling_params: Union[dict[str, Any], None] = None

        supported_tasks = (
            self.llm_engine.get_supported_tasks()  # type: ignore
            if envs.VLLM_USE_V1 else
            self.llm_engine.model_config.supported_tasks)
        logger.info("Supported_tasks: %s", supported_tasks)
        self.supported_tasks = supported_tasks

294
295
296
297
298
299
    def get_tokenizer(
        self,
        lora_request: Optional[LoRARequest] = None,
    ) -> AnyTokenizer:
        """Return the tokenizer associated with ``lora_request``.

        Delegates to ``get_lora_tokenizer`` on the engine's tokenizer group.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group()
        return tokenizer_group.get_lora_tokenizer(lora_request)
300
301

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer whose class name starts with ``"Cached"`` is assumed to
        be wrapped by ``get_cached_tokenizer`` already and is installed
        as-is; any other tokenizer is wrapped first.
        """
        group = self.llm_engine.get_tokenizer_group()

        # Cached tokenizer classes are created dynamically, so the class
        # name is the only thing to test. A user-defined tokenizer whose
        # class name happens to start with "Cached" will be misjudged.
        if not tokenizer.__class__.__name__.startswith("Cached"):
            tokenizer = get_cached_tokenizer(tokenizer)
        group.tokenizer = tokenizer
311

312
    def get_default_sampling_params(self) -> SamplingParams:
        """Return the model's default ``SamplingParams``.

        The overrides returned by ``model_config.get_diff_sampling_param()``
        are fetched on first use and cached in
        ``self.default_sampling_params``. An empty result yields a plain
        ``SamplingParams()``.
        """
        cached = self.default_sampling_params
        if cached is None:
            cached = self.llm_engine.model_config.get_diff_sampling_param()
            self.default_sampling_params = cached

        if not cached:
            return SamplingParams()
        return SamplingParams.from_optional(**cached)

320
321
322
323
324
325
326
    # Preferred (non-legacy) signature. It lists every keyword that the
    # implementation accepts, including `priority`, so that type checkers
    # accept calls such as `llm.generate(prompts, params, priority=[...])`.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        ...

335
    # LEGACY: single (prompt + optional token ids)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: multi (prompt + optional token ids)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: list[str],
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: single (token ids + optional prompt)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: multi (token ids + optional prompt)
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[list[str]] = None,
        sampling_params: Optional[
            Union[SamplingParams, list[SamplingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

    # LEGACY: single or multi token ids [pos-only]
    @overload
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[
            Union[LLMGuidedOptions, GuidedDecodingRequest]] = None,
    ) -> list[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
411
412
413
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[list[int]] = None,
    ) -> list[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Deprecated. Token IDs of the prompt(s) for the
                legacy API; pass token prompts through `prompts` instead.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. Either a
                single request or a list of requests.
            guided_options_request: Guided decoding options, if any. When
                given as a dict it must contain at most one entry and is
                converted into a `GuidedDecodingRequest`.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a generative model, or if
                `guided_options_request` is a dict with more than one entry.

        Note:
            Passing token IDs through the `prompt_token_ids` keyword is
            legacy and deprecated. Pass them through `prompts` instead
            (see [PromptType][vllm.inputs.PromptType]).
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "generate":
            raise ValueError(
                "LLM.generate() is only supported for generative models. "
                "Try passing `--runner generate` to use the model as a "
                "generative model.")

        if prompt_token_ids is not None:
            # Legacy path: build prompts from the separate token-id argument.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        # A dict of guided options is normalized into a GuidedDecodingRequest;
        # only one guided decoding mode may be requested at a time.
        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        # Truncation is only read from a single shared SamplingParams; a
        # per-prompt list leaves truncate_prompt_tokens unset here.
        tokenization_kwargs: dict[str, Any] = {}
        truncate_prompt_tokens = None
        if isinstance(sampling_params, SamplingParams):
            truncate_prompt_tokens = sampling_params.truncate_prompt_tokens

        # Presumably validates the truncation size against max_model_len and
        # fills tokenization_kwargs accordingly — confirm in entrypoints.utils.
        _validate_truncation_size(model_config.max_model_len,
                                  truncate_prompt_tokens, tokenization_kwargs)

        # Add any modality specific loras to the corresponding prompts
        lora_request = self._get_modality_specific_lora_reqs(
            parsed_prompts, lora_request)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            guided_options=guided_options_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
514

515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
    def _get_modality_specific_lora_reqs(
            self, parsed_prompts: Union[PromptType, Sequence[PromptType]],
            lora_request: Optional[Union[list[LoRARequest], LoRARequest]]):
        """Attach default modality-specific LoRAs to multimodal prompts.

        Args:
            parsed_prompts: A single prompt or a sequence of prompts.
            lora_request: A single LoRA request (or None) applied to every
                prompt, or a list paired one-to-one with the prompts.

        Returns:
            ``lora_request`` unchanged when no LoRA config or no
            ``default_mm_loras`` is configured, when the model is not
            multimodal, or when a list of requests does not match the
            number of prompts. Otherwise, a list with one optional
            ``LoRARequest`` per prompt, resolved by
            ``_resolve_single_prompt_mm_lora``.
        """
        # Grab the lora config off the vllm config on the engine,
        # since this is the same for both v0 & v1.
        lora_config = self.llm_engine.vllm_config.lora_config

        # If there's no lora config / default_mm_loras, or the model
        # isn't multimodal, leave the lora as is.
        if (lora_config is None
                or not self.llm_engine.model_config.is_multimodal_model
                or lora_config.default_mm_loras is None):
            return lora_request

        # A lone prompt may be a plain string, which is itself a Sequence;
        # wrap it so it is not iterated character by character.
        if isinstance(parsed_prompts, str) or not isinstance(
                parsed_prompts, Sequence):
            parsed_prompts = [parsed_prompts]

        if isinstance(lora_request, Sequence):
            if len(lora_request) != len(parsed_prompts):
                # Do not let zip() silently truncate a wrong-length list;
                # hand it back untouched, exactly as the path without
                # default multimodal LoRAs does.
                return lora_request
            optional_loras = lora_request
        else:
            optional_loras = [lora_request] * len(parsed_prompts)

        default_mm_loras = lora_config.default_mm_loras
        return [
            self._resolve_single_prompt_mm_lora(
                parsed_prompt,
                opt_lora_req,
                default_mm_loras,
            ) for parsed_prompt, opt_lora_req in zip(parsed_prompts,
                                                     optional_loras)
        ]

    def _resolve_single_prompt_mm_lora(self, parsed_prompt: PromptType,
                                       lora_request: Optional[LoRARequest],
                                       default_mm_loras: Optional[dict[str,
                                                                       str]]):
        """Choose the LoRA request for a single prompt.

        A default modality LoRA is chosen only when the prompt is a dict
        whose ``multi_modal_data`` uses exactly one modality that has an
        entry in ``default_mm_loras``. In that case a ``LoRARequest`` for
        it is returned, unless ``lora_request`` was provided explicitly.
        An explicit request always wins. In every other case,
        ``lora_request`` is returned unchanged.

        Args:
            parsed_prompt: The prompt to inspect.
            lora_request: The explicitly requested LoRA, if any.
            default_mm_loras: Mapping of modality name to LoRA path.
        """
        if not default_mm_loras or not isinstance(parsed_prompt, dict):
            return lora_request

        parsed_prompt = cast(Union[TextPrompt, TokensPrompt], parsed_prompt)

        # Treat a missing key, an explicit None and an empty mapping alike;
        # calling .keys() on None would otherwise raise AttributeError.
        mm_data = parsed_prompt.get("multi_modal_data")
        if not mm_data:
            return lora_request

        intersection = set(mm_data.keys()).intersection(
            default_mm_loras.keys())
        if not intersection:
            return lora_request
        if len(intersection) > 1:
            # TODO: Would be nice to be able to have multiple loras per prompt
            logger.warning(
                "Multiple modality specific loras were registered and would be"
                " used by a single prompt consuming several modalities;"
                " currently we only support one lora per request; as such,"
                " lora(s) registered with modalities: %s"
                " will be skipped", intersection)
            return lora_request

        # Build the LoRA request; the ID of the default mm lora is the
        # index of the modality name sorted alphabetically + 1.
        modality_name = intersection.pop()
        modality_lora_path = default_mm_loras[modality_name]
        modality_lora_id = sorted(default_mm_loras).index(modality_name) + 1

        # An explicitly provided request is always used; warn when its ID
        # differs from the ID the default modality LoRA would have had.
        if lora_request is not None:
            if lora_request.lora_int_id != modality_lora_id:
                logger.warning(
                    "A modality with a registered lora and a lora_request "
                    "with a different ID were provided; falling back to the "
                    "lora_request as we only apply one LoRARequest per prompt")
            return lora_request

        return LoRARequest(
            modality_name,
            modality_lora_id,
            modality_lora_path,
        )

592
    def collective_rpc(
        self,
        method: Union[str, Callable[..., _R]],
        timeout: Optional[float] = None,
        args: tuple = (),
        kwargs: Optional[dict[str, Any]] = None,
    ) -> list[_R]:
        """
        Run an RPC on every worker and collect the results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and sent to every worker. A callable must accept
                an additional `self` argument (the worker object) on top of
                the values passed through `args` and `kwargs`.
            timeout: Maximum time in seconds to wait for execution. Raises a
                [`TimeoutError`][] on timeout. `None` means wait indefinitely.
            args: Positional arguments forwarded to the worker method.
            kwargs: Keyword arguments forwarded to the worker method.

        Returns:
            A list containing the result from each worker.

        Note:
            Prefer this API for control messages only, and set up dedicated
            data-plane communication for passing data.
        """
        engine = self.llm_engine
        return engine.collective_rpc(method, timeout, args, kwargs)
621
622

    def apply_model(self, func: Callable[[nn.Module], _R]) -> list[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.

        Args:
            func: Callable invoked with the worker's `nn.Module`.

        Returns:
            Whatever `model_executor.apply_model` returns: the per-worker
            results of `func`.
        """
        executor = self.llm_engine.model_executor
        return executor.apply_model(func)
    def _get_beam_search_lora_requests(
        self,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]],
        prompts: list[Union[TokensPrompt, TextPrompt]],
    ) -> list[Optional[LoRARequest]]:
        """Get the optional LoRA request corresponding to each prompt.

        Args:
            lora_request: `None`, a single `LoRARequest` shared by every
                prompt, or a sequence holding exactly one entry per prompt.
            prompts: The beam search prompts; only their count is used.

        Returns:
            A list of length `len(prompts)` with the (possibly `None`) LoRA
            request for each prompt.

        Raises:
            ValueError: If a sequence is given whose length differs from the
                number of prompts.
            TypeError: If `lora_request` is of any other type.
        """
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != len(prompts):
            raise ValueError(
                "Lora request list should be the same length as the prompts")

        if lora_request is None or isinstance(lora_request, LoRARequest):
            return [lora_request] * len(prompts)

        if isinstance(lora_request, Sequence):
            # Length was validated above, so this is already one request per
            # prompt. Previously this case fell through to the TypeError,
            # making the annotated list form unusable. Copy to avoid aliasing
            # the caller's list.
            return list(lora_request)

        raise TypeError(f"Invalid lora_request type {type(lora_request)}")
    def beam_search(
        self,
        prompts: list[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        use_tqdm: bool = False,
    ) -> list[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompt dicts. Each is either a `TokensPrompt`
                (has `"prompt_token_ids"`, used as-is) or a `TextPrompt`
                (its `"prompt"` string is encoded with the tokenizer). Either
                may also carry `"multi_modal_data"` and
                `"mm_processor_kwargs"`, which are propagated to every beam.
            params: The beam search parameters.
            lora_request: LoRA request to use for generation, if any.
            use_tqdm: Whether to use tqdm to display the progress bar.

        Returns:
            One `BeamSearchOutput` per prompt, in input order. Each holds up
            to `params.beam_width` sequences ordered by descending beam score
            (as computed by `create_sort_beams_key_function`), with `text`
            decoded from the sequence tokens.
        """
        # TODO: how does beam search work together with length penalty,
        # frequency, penalty, and stopping criteria, etc.?
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        lora_requests = self._get_beam_search_lora_requests(
            lora_request, prompts)

        tokenizer = self.get_tokenizer()
        # Invariant for the whole search; read it once instead of once per
        # candidate token.
        eos_token_id = tokenizer.eos_token_id
        sort_beams_key = create_sort_beams_key_function(
            eos_token_id,
            length_penalty,
        )

        def create_tokens_prompt_from_beam(
                beam: BeamSearchSequence) -> TokensPrompt:
            token_prompt_kwargs: TokensPrompt = {
                "prompt_token_ids": beam.tokens
            }
            if beam.multi_modal_data is not None:
                token_prompt_kwargs["multi_modal_data"] = beam.multi_modal_data

            if beam.mm_processor_kwargs is not None:
                token_prompt_kwargs[
                    "mm_processor_kwargs"] = beam.mm_processor_kwargs
            return TokensPrompt(**token_prompt_kwargs)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: list[BeamSearchInstance] = []

        for lora_req, prompt in zip(lora_requests, prompts):
            # Add multimodal processor kwargs & data
            mm_kwargs = {}
            if "multi_modal_data" in prompt:
                mm_kwargs["multi_modal_data"] = prompt["multi_modal_data"]
            if "mm_processor_kwargs" in prompt:
                mm_kwargs["mm_processor_kwargs"] = prompt[
                    "mm_processor_kwargs"]

            if "prompt_token_ids" in prompt:
                prompt = cast(TokensPrompt, prompt)  # Needed for mypy
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])

            instances.append(
                BeamSearchInstance(
                    prompt_tokens,
                    lora_request=lora_req,
                    logprobs=None,
                    **mm_kwargs,
                ), )

        token_iter = range(max_tokens)
        if use_tqdm:
            token_iter = tqdm(token_iter,
                              desc="Beam search",
                              unit="token",
                              unit_scale=False)
            logger.warning(
                "The progress bar shows the upper bound on token steps and "
                "may finish early due to stopping conditions. It does not "
                "reflect instance-level progress.")

        for _ in token_iter:
            # Flatten the live beams of every instance into one batch.
            # chain.from_iterable is linear in the total number of beams;
            # sum(lists, []) re-copied the accumulator for every instance.
            all_beams: list[BeamSearchSequence] = list(
                itertools.chain.from_iterable(instance.beams
                                              for instance in instances))
            if not all_beams:
                break

            # [start, end) offsets of each instance's beams within all_beams,
            # used to map batched outputs back to their instance.
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: list[tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            # create the corresponding batch entries for prompt & optional lora
            prompts_batch, lora_req_batch = zip(
                *[(create_tokens_prompt_from_beam(beam), beam.lora_request)
                  for beam in all_beams])

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False,
                                   lora_request=lora_req_batch)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                lora_request=current_beam.lora_request,
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob,
                                multi_modal_data=current_beam.multi_modal_data,
                                mm_processor_kwargs=current_beam.
                                mm_processor_kwargs)

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive when the token budget ran out compete with
            # the EOS-terminated ones for the final slots.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs
    def chat(
        self,
        messages: Union[list[ChatCompletionMessageParam],
                        list[list[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        list[SamplingParams]]] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        chat_template_kwargs: Optional[dict[str, Any]] = None,
        mm_processor_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the [generate][] method to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any. Also
                selects the tokenizer used to render the conversations.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

                - "string" will render the content as a string.
                  Example: `"Who are you?"`
                - "openai" will render the content as a list of dictionaries,
                  similar to OpenAI schema.
                  Example: `[{"type": "text", "text": "Who are you?"}]`

            add_generation_prompt: If True, adds a generation template
                to each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                `True` if `add_generation_prompt` is also `True`.
            tools: Optional tool definitions. They are passed to the chat
                template and used when resolving the content format.
            chat_template_kwargs: Additional kwargs to pass to the chat
                template. Keys given here override the explicit
                `chat_template`, `add_generation_prompt`,
                `continue_final_message` and `tools` arguments.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of `RequestOutput` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: list[list[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is list[list[...]]
            list_of_messages = cast(list[list[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is list[...]
            list_of_messages = [
                cast(list[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer(lora_request)
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            tools,
            chat_template_content_format,
            tokenizer,
            model_config=model_config,
        )

        _chat_template_kwargs: dict[str, Any] = dict(
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
        )
        # Applied last, so user-supplied template kwargs win over the
        # explicit arguments above.
        _chat_template_kwargs.update(chat_template_kwargs or {})

        prompts: list[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            if isinstance(tokenizer, MistralTokenizer):
                prompt_token_ids = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    **_chat_template_kwargs,
                )
            else:
                prompt_str = apply_hf_chat_template(
                    tokenizer=tokenizer,
                    conversation=conversation,
                    model_config=model_config,
                    **_chat_template_kwargs,
                )
                # Special tokens are already included in chat templates so
                # should not be added by the tokenizer in this case.
                prompt_token_ids = tokenizer.encode(prompt_str,
                                                    add_special_tokens=False)

            prompt = TokensPrompt(prompt_token_ids=prompt_token_ids)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[int]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: list[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[list[list[int]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[int],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[list[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: list[list[int]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[list[int], list[list[int]]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        ...
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, list[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]] = None,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
        pooling_task: PoolingTask = "encode",
        tokenization_kwargs: Optional[dict[str, Any]] = None,
    ) -> list[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Each given
                `PoolingParams` is verified against `pooling_task` and the
                model config.
            prompt_token_ids: Legacy, deprecated way of passing token IDs;
                combined with `prompts` via `_convert_v1_inputs`.
            truncate_prompt_tokens: Optional truncation length for prompt
                tokens, validated with `_validate_truncation_size` against the
                model's `max_model_len`. Only applied when
                `tokenization_kwargs` is None.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.
            pooling_task: Override the pooling task to use.
            tokenization_kwargs: Extra tokenizer kwargs passed through to
                request validation. When given, it is used as-is.

        Returns:
            A list of `PoolingRequestOutput` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model's runner type is not "pooling".

        Note:
            Using `prompts` and `prompt_token_ids` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the `prompts` parameter.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.encode() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, list[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        if isinstance(pooling_params, PoolingParams):
            pooling_params.verify(pooling_task, model_config)
        else:
            for pooling_param in pooling_params:
                pooling_param.verify(pooling_task, model_config)

        # NOTE(review): when the caller supplies tokenization_kwargs,
        # truncate_prompt_tokens is not consulted at all; confirm callers do
        # not rely on both being honored together.
        if tokenization_kwargs is None:
            tokenization_kwargs = dict[str, Any]()
            _validate_truncation_size(model_config.max_model_len,
                                      truncate_prompt_tokens,
                                      tokenization_kwargs)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[EmbeddingRequestOutput]:
        """
        Generate an embedding vector for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            truncate_prompt_tokens: Optional truncation length for prompt
                tokens, forwarded unchanged to [encode][].
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            embedding vectors in the same order as the input prompts.

        Raises:
            ValueError: If "embed" is not among the model's supported tasks.
        """
        if "embed" not in self.supported_tasks:
            raise ValueError(
                "Embedding API is not supported by this model. "
                "Try converting the model using `--convert embed`.")

        items = self.encode(
            prompts,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            pooling_params=pooling_params,
            lora_request=lora_request,
            pooling_task="embed",
        )

        return [EmbeddingRequestOutput.from_base(item) for item in items]

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See [PromptType][vllm.inputs.PromptType]
                for more details about the format of each prompts.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ClassificationRequestOutput` objects containing the
            classification results in the same order as the input prompts.

        Raises:
            ValueError: If "classify" is not among the model's supported
                tasks.
        """
        if "classify" not in self.supported_tasks:
            raise ValueError(
                "Classification API is not supported by this model. "
                "Try converting the model using `--convert classify`.")

        items = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_task="classify",
        )

        return [ClassificationRequestOutput.from_base(item) for item in items]
    def _embedding_score(
        self,
        tokenizer: AnyTokenizer,
        text_1: list[Union[str, TextPrompt, TokensPrompt]],
        text_2: list[Union[str, TextPrompt, TokensPrompt]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score inputs by the cosine similarity of their embeddings.

        `text_1` and `text_2` are embedded together in a single `encode`
        call with the "embed" pooling task. The output is then split back
        into the two sides. If `text_1` has exactly one item, its embedding
        is repeated to match the length of `text_2` before both lists are
        passed to `_cosine_similarity`.

        Args:
            tokenizer: Tokenizer forwarded to `_cosine_similarity`.
            text_1: First-side inputs (one item, or one per `text_2` item).
            text_2: Second-side inputs.
            truncate_prompt_tokens: Forwarded to `encode`.
            use_tqdm: Forwarded to `encode`.
            lora_request: Forwarded to `encode`.

        Returns:
            The similarity results wrapped as `ScoringRequestOutput`s.
        """
        encoded_output: list[PoolingRequestOutput] = self.encode(
            text_1 + text_2,
            truncate_prompt_tokens=truncate_prompt_tokens,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            pooling_task="embed",
        )

        # Outputs come back in input order: text_1 first, then text_2.
        encoded_output_1: list[PoolingRequestOutput] = encoded_output[
            0:len(text_1)]
        encoded_output_2: list[PoolingRequestOutput] = encoded_output[
            len(text_1):]

        if len(encoded_output_1) == 1:
            # Broadcast the single first-side embedding against every
            # second-side item (it was only encoded once above).
            encoded_output_1 = encoded_output_1 * len(encoded_output_2)

        scores = _cosine_similarity(tokenizer=tokenizer,
                                    embed_1=encoded_output_1,
                                    embed_2=encoded_output_2)

        items = self.engine_class.validate_outputs(scores,
                                                   PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(item) for item in items]

    def _cross_encoding_score(
        self,
        tokenizer: AnyTokenizer,
        data_1: Union[list[str], list[ScoreContentPartParam]],
        data_2: Union[list[str], list[ScoreContentPartParam]],
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Score each `(data_1[i], data_2[i])` pair with a cross-encoder.

        A single-element `data_1` is repeated so it pairs with every entry of
        `data_2`. Each pair becomes one engine prompt:

        - multi-modal models: built by `get_score_prompt`;
        - text models: produced by calling `tokenizer` directly.

        The prompts are submitted as pooling requests with `task="score"`.
        The finished outputs are wrapped as `ScoringRequestOutput`.

        Raises:
            ValueError: If `tokenizer` is a `MistralTokenizer`. Errors raised
                by `_validate_truncation_size` propagate unchanged.
        """
        cfg = self.llm_engine.model_config

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "Score API is not supported for Mistral tokenizer")

        # 1 -> N: repeat the lone query so it lines up with every document.
        queries = data_1 * len(data_2) if len(data_1) == 1 else data_1

        score_params = PoolingParams(task="score")
        tok_kwargs: dict[str, Any] = {}
        # `tok_kwargs` is passed by reference; presumably this fills in
        # truncation options used by the tokenizer calls below.
        _validate_truncation_size(cfg.max_model_len, truncate_prompt_tokens,
                                  tok_kwargs)

        def mm_prompt(first, second):
            # Multi-modal prompt construction is delegated to score_utils.
            _, engine_prompt = get_score_prompt(
                model_config=cfg,
                data_1=first,
                data_2=second,
                tokenizer=tokenizer,
                tokenization_kwargs=tok_kwargs,
            )
            return engine_prompt

        def text_prompt(first, second):
            if cfg.use_pad_token:
                # Cross-encoder models default to using the pad token, so
                # the pair is encoded as `text` / `text_pair`.
                encoded = tokenizer(
                    text=first,  # type: ignore[arg-type]
                    text_pair=second,  # type: ignore[arg-type]
                    **tok_kwargs)
            else:
                # "LLM as reranker" models default to not using the pad
                # token, so the two texts are simply concatenated.
                encoded = tokenizer(
                    text=first + second,  # type: ignore[operator]
                    **tok_kwargs)
            return TokensPrompt(
                prompt_token_ids=encoded["input_ids"],
                token_type_ids=encoded.get("token_type_ids"))

        build = mm_prompt if cfg.is_multimodal_model else text_prompt
        engine_prompts = [build(a, b) for a, b in zip(queries, data_2)]

        self._validate_and_add_requests(
            prompts=engine_prompts,
            params=score_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

        finished = self.engine_class.validate_outputs(
            self._run_engine(use_tqdm=use_tqdm), PoolingRequestOutput)
        return [ScoringRequestOutput.from_base(out) for out in finished]

1337
1338
    def score(
        self,
        data_1: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        data_2: Union[SingletonPrompt, Sequence[SingletonPrompt],
                      ScoreMultiModalParam],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[list[LoRARequest], LoRARequest]] = None,
    ) -> list[ScoringRequestOutput]:
        """Generate similarity scores for all pairs `<text,text_pair>` or
          `<multi-modal data, multi-modal data pair>`.

        The inputs can be `1 -> 1`, `1 -> N` or `N -> N`.
        In the `1 -> N` case the `data_1` input will be replicated `N`
        times to pair with the `data_2` inputs.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This class automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your inputs into a single list and pass it to this method.

        Supports both text and multi-modal data (images, etc.) when used with
        appropriate multi-modal models. For multi-modal inputs, ensure the
        prompt structure matches the model's expected input format.

        Args:
            data_1: Can be a single prompt, a list of prompts or
                `ScoreMultiModalParam`, which can contain either text or
                multi-modal data. When it holds more than one item, it must
                have the same length as `data_2`.
            data_2: The data to pair with the query to form the input to
                the LLM. Can be text or multi-modal data. See [PromptType]
                [vllm.inputs.PromptType] for more details about the format of
                each prompt.
            truncate_prompt_tokens: Optional limit on the number of prompt
                tokens, forwarded unchanged to the cross-encoder or
                embedding scoring path.
            use_tqdm: If `True`, shows a tqdm progress bar.
                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
                it is used to create the progress bar.
                If `False`, no progress bar is created.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `ScoringRequestOutput` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If any of the following holds:
                - the model is not run as a pooling model;
                - it supports neither the `embed` nor the `classify` task;
                - it is a cross-encoder whose `num_labels` is not 1;
                - for a text-only model, an input is a
                  `ScoreMultiModalParam`, a multi-modal prompt, or not
                  resolvable to a string.
                Errors from `_validate_score_input_lens` (presumably for
                mismatched input lengths) propagate to the caller.
        """
        model_config = self.llm_engine.model_config
        runner_type = model_config.runner_type
        if runner_type != "pooling":
            raise ValueError(
                "LLM.score() is only supported for pooling models. "
                "Try passing `--runner pooling` to use the model as a "
                "pooling model.")

        supported_tasks = self.supported_tasks
        if all(t not in supported_tasks for t in ("embed", "classify")):
            raise ValueError("Score API is not supported by this model. "
                             "Try converting the model using "
                             "`--convert embed` or `--convert classify`.")

        if (model_config.is_cross_encoder
                and getattr(model_config.hf_config, "num_labels", 0) != 1):
            raise ValueError("Score API is only enabled for num_labels == 1.")

        tokenizer = self.get_tokenizer()

        if not model_config.is_multimodal_model:

            def check_data_type(data: Union[SingletonPrompt,
                                            Sequence[SingletonPrompt],
                                            ScoreMultiModalParam]):
                # A dict with "content" is a ScoreMultiModalParam, which a
                # text-only model cannot consume.
                if isinstance(data, dict) and "content" in data:
                    raise ValueError("ScoreMultiModalParam is not supported "
                                     f"for {model_config.architecture}")

            check_data_type(data_1)
            check_data_type(data_2)

            # The tokenizer for models such as
            # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
            # lists of tokens to the `text` and `text_pair` kwargs, so every
            # prompt is normalized to a plain string (token ids are decoded).
            def ensure_str(prompt: SingletonPrompt) -> str:
                if isinstance(prompt, dict):
                    if "multi_modal_data" in prompt:
                        raise ValueError("Multi-modal prompt is not "
                                         "supported for scoring")
                    elif "prompt_token_ids" in prompt:
                        prompt = tokenizer.decode(
                            cast(TokensPrompt, prompt)["prompt_token_ids"])
                    elif "prompt" in prompt:
                        prompt = cast(TextPrompt, prompt)["prompt"]
                # Explicit check rather than `assert`. An assert is stripped
                # under `python -O` and would let a non-string reach the
                # tokenizer.
                if type(prompt) is not str:
                    raise ValueError(
                        "Scoring prompts must be a str, TextPrompt or "
                        f"TokensPrompt, got {type(prompt).__name__}")
                return prompt

            if isinstance(data_1, (str, dict)):
                # Convert a single prompt to a list.
                data_1 = [data_1]  # type: ignore[list-item]

            data_1 = [ensure_str(t) for t in data_1]

            if isinstance(data_2, (str, dict)):
                # Convert a single prompt to a list.
                data_2 = [data_2]  # type: ignore[list-item]

            data_2 = [ensure_str(t) for t in data_2]

        # Unwrap ScoreMultiModalParam to its content parts; wrap bare strings.
        if isinstance(data_1, dict) and "content" in data_1:
            data_1 = data_1.get("content")  # type: ignore[assignment]
        elif isinstance(data_1, str):
            data_1 = [data_1]

        if isinstance(data_2, dict) and "content" in data_2:
            data_2 = data_2.get("content")  # type: ignore[assignment]
        elif isinstance(data_2, str):
            data_2 = [data_2]

        _validate_score_input_lens(data_1, data_2)  # type: ignore[arg-type]

        if model_config.is_cross_encoder:
            return self._cross_encoding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request)
        else:
            return self._embedding_score(
                tokenizer,
                data_1,  # type: ignore[arg-type]
                data_2,  # type: ignore[arg-type]
                truncate_prompt_tokens,
                use_tqdm,
                lora_request)
1471

1472
1473
1474
1475
1476
1477
    def start_profile(self) -> None:
        """Ask the wrapped engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the wrapped engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

1478
1479
    def reset_prefix_cache(self, device: Optional[Device] = None) -> bool:
        """Reset the engine's prefix cache.

        Args:
            device: Optional device selector, passed through unchanged to
                the engine.

        Returns:
            The boolean reported by the engine's `reset_prefix_cache`.
        """
        engine = self.llm_engine
        return engine.reset_prefix_cache(device)
1480

1481
1482
1483
1484
1485
1486
    def sleep(self, level: int = 1):
        """
        Put the engine into sleep mode. No requests should be processed
        while it sleeps: the caller must make sure nothing is in flight
        from now until `wake_up` is called.

        Args:
            level: How deeply to sleep.
                - Level 1 offloads the model weights (backed up in CPU
                  memory) and discards the kv cache, whose content is
                  forgotten. It suits waking the engine again with the same
                  model. Make sure there is enough CPU memory to hold the
                  model weights.
                - Level 2 discards both the model weights and the kv cache,
                  forgetting the content of both. It suits waking the engine
                  to run a different or updated model, where the previous
                  weights are not needed, and reduces CPU memory pressure.
        """
        # The prefix cache is cleared first; its boolean result is not
        # checked here.
        self.reset_prefix_cache()
        engine = self.llm_engine
        engine.sleep(level=level)

1503
    def wake_up(self, tags: Optional[list[str]] = None):
        """
        Bring the engine back from sleep mode; the [sleep][] method
        describes the sleep levels.

        Args:
            tags: Optional list naming which engine memory allocations to
                restore. Values must be in `("weights", "kv_cache")`. `None`
                reallocates all memory. Before the engine is used again,
                wake_up must have been called with all tags (or with None).
        """
        engine = self.llm_engine
        engine.wake_up(tags)
1516

1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
    def get_metrics(self) -> list["Metric"]:
        """Return a snapshot of aggregated metrics from Prometheus.

        Returns:
            A list of `Metric` objects, as produced by the V1 engine's
            `get_metrics()`, capturing the current state of all aggregated
            metrics from Prometheus.

        Raises:
            RuntimeError: If the underlying engine is not the V1 LLM engine.

        Note:
            This method is only available with the V1 LLM engine.
        """
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

        # Explicit check instead of `assert`: the precondition must still be
        # enforced (with a clear message) when Python runs with `-O`.
        if not isinstance(self.llm_engine, V1LLMEngine):
            raise RuntimeError(
                "LLM.get_metrics() is only available with the V1 LLM engine.")
        return self.llm_engine.get_metrics()

1531
1532
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, list[str]]],
        prompt_token_ids: Optional[Union[list[int], list[list[int]]]],
    ):
        """Convert legacy `prompts` / `prompt_token_ids` arguments into a
        list of prompt objects, one per request.

        Each argument may describe a single request (`str` / `list[int]`) or
        a batch (`list[str]` / `list[list[int]]`). Both are normalized with
        `parse_and_batch_prompt`. When both are given, they must describe the
        same number of requests, and the text prompts are used.

        Returns:
            A list of `TextPrompt` (when `prompts` is given) or otherwise
            `TokensPrompt` objects.

        Raises:
            ValueError: If neither argument is given, or both are given but
                describe a different number of requests.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is None and prompt_token_ids is None:
            raise ValueError(
                "Either prompts or prompt_token_ids must be provided.")

        # Normalize to batches *before* comparing lengths. Comparing the raw
        # arguments would pit the character count of a single `str` against
        # the token count of a single `list[int]`.
        text_batch = (None if prompts is None else
                      [p["content"] for p in parse_and_batch_prompt(prompts)])
        token_batch = (None if prompt_token_ids is None else [
            p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
        ])

        if (text_batch is not None and token_batch is not None
                and len(text_batch) != len(token_batch)):
            raise ValueError(
                "The lengths of prompts and prompt_token_ids must be the same."
            )

        if text_batch is not None:
            return [TextPrompt(prompt=text) for text in text_batch]
        return [TokensPrompt(prompt_token_ids=ids) for ids in token_batch]
1572
1573
1574

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[list[int]] = None,
    ) -> None:
        """Validate per-request arguments, then add every prompt to the engine.

        Args:
            prompts: A single prompt (`str` or `dict`) or a sequence of them.
            params: One params object shared by all prompts, or one per
                prompt. Every `SamplingParams` is mutated in place:
                `guided_options` are applied and `output_kind` is set to
                `RequestOutputKind.FINAL_ONLY`.
            use_tqdm: `True` for a default tqdm bar while adding requests, a
                callable to build a custom one, or `False` for none.
            lora_request: One LoRA request shared by all prompts, or one per
                prompt.
            tokenization_kwargs: Forwarded to each `_add_request` call.
            guided_options: Deprecated; emits a `DeprecationWarning`.
            priority: Optional per-prompt priorities. When omitted (or
                empty), every request gets priority 0.

        Raises:
            ValueError: If `params`, `lora_request` or `priority` is a
                per-request sequence whose length differs from the number of
                prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Validate up front. Otherwise a short list raises IndexError only
        # after some requests were already added to the engine, and a long
        # one is silently truncated.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        it = prompts
        if use_tqdm:
            tqdm_func = use_tqdm if callable(use_tqdm) else tqdm
            it = tqdm_func(it, desc="Adding requests")

        for i, prompt in enumerate(it):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                tokenization_kwargs=tokenization_kwargs,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                priority=priority[i] if priority else 0,
            )
1628

1629
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        tokenization_kwargs: Optional[dict[str, Any]] = None,
        lora_request: Optional[LoRARequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit one prompt to the engine under a freshly allocated id.

        The id is the next value drawn from `self.request_counter`, converted
        to `str`. All other arguments are forwarded to
        `self.llm_engine.add_request`.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            prompt,
            params,
            lora_request=lora_request,
            tokenization_kwargs=tokenization_kwargs,
            priority=priority,
        )
1646

1647
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy legacy `guided_options` onto `params.guided_decoding`.

        Returns:
            `params` itself. It is mutated in place only when
            `guided_options` is given.

        Raises:
            ValueError: If `guided_options` is given while
                `params.guided_decoding` is already set.
        """
        if guided_options is not None:
            if params.guided_decoding is not None:
                raise ValueError("Cannot set both guided_options_request and "
                                 "params.guided_decoding.")

            params.guided_decoding = GuidedDecodingParams(
                json=guided_options.guided_json,
                regex=guided_options.guided_regex,
                choice=guided_options.guided_choice,
                grammar=guided_options.guided_grammar,
                json_object=guided_options.guided_json_object,
                backend=guided_options.guided_decoding_backend,
                whitespace_pattern=guided_options.guided_whitespace_pattern,
                structural_tag=guided_options.structural_tag,
            )
        return params

1670
    def _run_engine(
        self,
        *,
        use_tqdm: Union[bool, Callable[..., tqdm]] = True
    ) -> list[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: `True` for a default tqdm progress bar, a callable to
                build a custom one, or `False` for none. For
                `RequestOutput`s the bar's postfix shows estimated input and
                output token throughput.

        Returns:
            Every finished output, ordered by integer request id.
        """
        engine = self.llm_engine

        if use_tqdm:
            num_requests = engine.get_num_unfinished_requests()
            make_bar = use_tqdm if callable(use_tqdm) else tqdm
            pbar = make_bar(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        finished: list[Union[RequestOutput, PoolingRequestOutput]] = []
        in_tokens = 0
        out_tokens = 0

        while engine.has_unfinished_requests():
            for out in engine.step():
                if not out.finished:
                    continue
                finished.append(out)
                if not use_tqdm:
                    continue

                if isinstance(out, RequestOutput):
                    # Only RequestOutput carries token counts for the
                    # throughput estimate.
                    num_seqs = len(out.outputs)
                    assert out.prompt_token_ids is not None
                    in_tokens += len(out.prompt_token_ids) * num_seqs
                    in_rate = in_tokens / pbar.format_dict["elapsed"]
                    out_tokens += sum(
                        len(seq.token_ids) for seq in out.outputs)
                    out_rate = out_tokens / pbar.format_dict["elapsed"]
                    pbar.postfix = (
                        f"est. speed input: {in_rate:.2f} toks/s, "
                        f"output: {out_rate:.2f} toks/s")
                    pbar.update(num_seqs)
                else:
                    pbar.update(1)

                if pbar.n == num_requests:
                    pbar.refresh()

        if use_tqdm:
            pbar.close()

        # Requests may finish out of submission order; restore it by id.
        finished.sort(key=lambda out: int(out.request_id))
        return finished