llm.py 48.2 KB
Newer Older
1
import itertools
2
import json
3
import warnings
4
from contextlib import contextmanager
Joe Runde's avatar
Joe Runde committed
5
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
6
                    Union, cast, overload)
7

8
from tqdm import tqdm
9
from typing_extensions import deprecated
10

11
from vllm import envs
12
13
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
14
from vllm.config import CompilationConfig
15
16
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
17
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
18
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
19
                                         ChatTemplateContentFormatOption,
20
21
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
22
23
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
24
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
25
from vllm.inputs.parse import parse_and_batch_prompt
26
from vllm.logger import init_logger
27
from vllm.lora.request import LoRARequest
28
29
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
30
from vllm.outputs import PoolingRequestOutput, RequestOutput
31
from vllm.pooling_params import PoolingParams
32
from vllm.prompt_adapter.request import PromptAdapterRequest
33
34
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
35
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
36
37
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
38
from vllm.usage.usage_lib import UsageContext
39
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
40

41
42
logger = init_logger(__name__)

43
44

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
45
46
47
48
49
50
51
52
53
54
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
55
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
56
57
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
58
59
60
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
61
62
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
63
64
65
66
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
67
68
69
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
70
71
72
73
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
74
        quantization: The method used to quantize the model weights. Currently,
75
            we support "awq", "gptq", and "fp8" (experimental).
76
77
78
79
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
80
81
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
82
83
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
84
85
86
87
88
89
90
91
92
93
94
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
95
96
97
98
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
99
100
101
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
102
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
103
            When a sequence has context length larger than this, we fall back
104
105
106
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
107
108
109
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
110
111
112
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
113
114
115
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
116
117
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)
nunjunj's avatar
nunjunj committed
118

119
120
121
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
122
    """
123

124
125
126
    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

127
128
129
130
131
132
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

133
134
135
136
137
138
139
140
141
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Context manager that enables the legacy-API deprecation flag.

        ``DEPRECATE_LEGACY`` is set to True while the ``with`` body runs and
        is set back to False when the body exits, whether it finishes
        normally or raises.
        """
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Without the ``finally`` an exception raised inside the ``with``
            # body would leave the class-level flag permanently enabled.
            cls.DEPRECATE_LEGACY = False

142
143
144
145
146
147
148
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """
        LLM constructor.

        Collects the arguments into an :class:`~vllm.EngineArgs`, picks the
        engine class via :meth:`get_engine_class`, builds the engine from
        those arguments and creates a fresh request counter.

        ``disable_log_stats`` is set to True unless the caller supplies it
        through ``kwargs``. An integer or dictionary ``compilation_config``
        is converted with ``CompilationConfig.from_cli``; any other value is
        forwarded unchanged.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.
        """
        # Only fills in the value when the caller did not pass one.
        kwargs.setdefault("disable_log_stats", True)

        # Start from the raw value (covers None and pre-built objects) and
        # convert only the two shorthand forms.
        compilation_config_instance: Any = compilation_config
        if isinstance(compilation_config, int):
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        elif isinstance(compilation_config, dict):
            compilation_config_instance = CompilationConfig.from_cli(
                json.dumps(compilation_config))

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Logic to switch between engines is done at runtime instead of import
        # to avoid import order issues
        self.engine_class = self.get_engine_class()

        # TODO(rob): enable mp by default (issue with fork vs spawn)
        self.llm_engine = self.engine_class.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

        self.request_counter = Counter()

Joe Runde's avatar
Joe Runde committed
237
238
239
240
241
242
243
244
    @staticmethod
    def get_engine_class() -> Type[LLMEngine]:
        """Return the engine class selected by ``envs.VLLM_USE_V1``."""
        if not envs.VLLM_USE_V1:
            return LLMEngine

        # Lazy import: the v1 package isn't distributed
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        return V1LLMEngine  # type: ignore

245
246
247
248
249
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer stored on the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer whose class name starts with ``"Cached"`` is stored as
        is; any other tokenizer is first wrapped with
        ``get_cached_tokenizer``.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        looks_cached = type(tokenizer).__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))
258

259
    # Typing-only overloads of ``generate``. The five LEGACY variants accept
    # ``prompt_token_ids`` separately from ``prompts`` and are marked as
    # deprecated; the last overload is the current call form.
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    # Current form: prompts are positional-only, everything else is
    # keyword-only.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
338
339
340
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing token ids. When given,
                it is combined with ``prompts`` through
                ``_convert_v1_inputs``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A
                dictionary is converted to ``GuidedDecodingRequest`` and may
                hold at most one entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the 'generate'
                task, or if a ``guided_options_request`` dictionary has more
                than one entry.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``prompts`` parameter.
        """
        model_config = self.llm_engine.model_config
        task = model_config.task
        if task != "generate":
            error_parts = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]
            # Give a hint when the model could generate but was started for
            # a different task.
            if "generate" in model_config.supported_tasks:
                error_parts.append(
                    "Your model supports the 'generate' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task generate`.")
            raise ValueError(" ".join(error_parts))

        if prompt_token_ids is None:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)
        else:
            # Legacy call style: merge text prompts and token ids.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        # Fall back to the default sampling params when none are given.
        effective_params = (SamplingParams()
                            if sampling_params is None else sampling_params)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=effective_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        engine_outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(engine_outputs,
                                                  RequestOutput)
435

436
437
438
    def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        Returns:
            One ``BeamSearchOutput`` per prompt, in prompt order. Each holds
            at most ``beam_width`` sequences taken from the finished and
            still-active beams, sorted by ``get_beam_search_score`` in
            descending order, with ``text`` decoded from the tokens.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.get_tokenizer()
        # Loop-invariant; read once instead of on every comparison.
        eos_token_id = tokenizer.eos_token_id

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            # Token-id prompts are used as is; text prompts are tokenized.
            prompt_tokens = prompt if isinstance(
                prompt, list) else tokenizer.encode(prompt)
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten the live beams of all instances into a single batch.
            # chain.from_iterable avoids the quadratic copying of
            # ``sum(lists, [])``.
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))

            if not all_beams:
                break

            # (start, end) offsets of each instance's beams inside all_beams.
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        step_logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in step_logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs +
                                [step_logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                # Keep only the best `beam_width` candidates alive.
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive when the loop ends compete with finished ones.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
544
545
    def chat(
        self,
546
547
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
nunjunj's avatar
nunjunj committed
548
549
550
551
552
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
553
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
554
        add_generation_prompt: bool = True,
555
        continue_final_message: bool = False,
556
        tools: Optional[List[Dict[str, Any]]] = None,
557
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
nunjunj's avatar
nunjunj committed
558
559
    ) -> List[RequestOutput]:
        """
560
        Generate responses for a chat conversation.
nunjunj's avatar
nunjunj committed
561

562
563
564
565
566
567
        The chat conversation is converted into a text prompt using the
        tokenizer and calls the :meth:`generate` method to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.
nunjunj's avatar
nunjunj committed
568
569

        Args:
570
571
572
573
574
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

nunjunj's avatar
nunjunj committed
575
576
577
578
579
580
581
582
583
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
584
585
586
587
588
589
590
591
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

592
            add_generation_prompt: If True, adds a generation template
nunjunj's avatar
nunjunj committed
593
                to each message.
594
            continue_final_message: If True, continues the final message in
595
596
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
597
598
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.
nunjunj's avatar
nunjunj committed
599
600
601
602
603

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
604
        list_of_messages: List[List[ChatCompletionMessageParam]]
nunjunj's avatar
nunjunj committed
605

606
607
608
        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
609
610
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
611
        else:
612
            # messages is List[...]
613
614
615
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]
616

617
618
619
620
621
622
623
624
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )

625
626
627
        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
628
629
630
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
631
            conversation, mm_data = parse_chat_messages(
632
633
634
635
636
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )
637
638
639
640
641
642
643
644

            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
645
                    continue_final_message=continue_final_message,
646
647
648
649
650
651
652
653
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
654
                    continue_final_message=continue_final_message,
655
656
657
658
659
660
661
662
663
664
665
666
                    tools=tools,
                )

            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

667
668
669
            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

670
            prompts.append(prompt)
671

nunjunj's avatar
nunjunj committed
672
        return self.generate(
673
            prompts,
674
            sampling_params=sampling_params,
nunjunj's avatar
nunjunj committed
675
676
677
678
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

679
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[PoolingRequestOutput]:
        # Typing-only stub for the legacy call shape: one text prompt,
        # optionally accompanied by its token ids.
        ...
691

692
    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[PoolingRequestOutput]:
        # Typing-only stub for the legacy call shape: a list of text prompts,
        # optionally accompanied by one token-id list per prompt.
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[PoolingRequestOutput]:
        # Typing-only stub for the legacy call shape: a single token-id list
        # passed by keyword, with the text prompt being optional.
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[PoolingRequestOutput]:
        # Typing-only stub for the legacy call shape: several token-id lists
        # passed by keyword, with the text prompts being optional.
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[PoolingRequestOutput]:
        # Typing-only stub for the legacy call shape: token ids given as the
        # third positional argument, with ``None`` for the first two.
        ...

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        pooling_params: Optional[
            Union[PoolingParams, Sequence[PoolingParams]]
        ] = None,
        use_tqdm: bool = True,
        lora_request: Optional[
            Union[List[LoRARequest], LoRARequest]
        ] = None,
    ) -> List[PoolingRequestOutput]:
        # Current (non-legacy) signature: prompts are positional-only and
        # every other option must be passed by keyword.
        ...

nunjunj's avatar
nunjunj committed
758
759
760
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        """Computes pooled outputs (embeddings) for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            prompt_token_ids: Legacy way of passing token IDs. When given,
                ``prompts`` is treated as plain text (``str`` or
                ``List[str]``) and both are converted by
                ``_convert_v1_inputs``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the
                ``"embedding"`` task.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and may be
            deprecated in the future. You should instead pass token IDs as
            part of the ``prompts`` parameter.
        """
        task = self.llm_engine.model_config.task
        if task != "embedding":
            messages = ["LLM.encode() is only supported for embedding models."]

            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "embedding" in supported_tasks:
                # The model could do this; it was just started for another
                # task, so tell the user how to fix it.
                messages.append(
                    "Your model supports the 'embedding' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task embedding`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call shape: merge text prompts and token IDs.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)
836

837
838
839
840
841
842
843
844
845
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        """Scores every ``(text_1, text_2)`` pair with a cross-encoder model.

        Supported input shapes are 1 -> 1, 1 -> N and N -> N. In the 1 -> N
        case the single ``text_1`` entry is repeated N times so that it is
        paired with each ``text_2`` entry. Each pair is tokenized together
        and submitted to the engine as one token prompt.

        Args:
            text_1: A single prompt or a list of prompts. A list with more
                than one entry must have the same length as ``text_2``.
            text_2: The prompt(s) to pair with ``text_1``. See
                :class:`~vllm.inputs.PromptType` for more details about the
                format of each prompt.
            truncate_prompt_tokens: If given, the tokenizer is called with
                ``truncation=True`` and this value as ``max_length``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not an embedding cross-encoder, the
                tokenizer is a ``MistralTokenizer``, a prompt carries
                multi-modal data, or the input lengths are invalid.
        """
        model_config = self.llm_engine.model_config
        task = model_config.task
        if task != "embedding":
            reasons = ["LLM.score() is only supported for embedding models."]
            if "embedding" in model_config.supported_tasks:
                reasons.append(
                    "Your model supports the 'embedding' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task embedding`.")
            raise ValueError(" ".join(reasons))

        if not self.llm_engine.model_config.is_cross_encoder:
            raise ValueError("Your model does not support the cross encoding")

        tokenizer = self.llm_engine.get_tokenizer()
        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # Some tokenizers (e.g. "cross-encoder/ms-marco-MiniLM-L-6-v2") do
        # not accept token lists for `text` / `text_pair`, so every prompt
        # is reduced to a plain string first.
        def as_text(item: SingletonPrompt):
            if isinstance(item, dict):
                if "multi_modal_data" in item:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for cross encoding")
                if "prompt_token_ids" in item:
                    token_ids = cast(TokensPrompt, item)["prompt_token_ids"]
                    item = tokenizer.decode(token_ids)
                elif "prompt" in item:
                    item = cast(TextPrompt, item)["prompt"]
            assert type(item) is str
            return item

        # A lone prompt is wrapped into a one-element list.
        left_items = [text_1] if isinstance(text_1, (str, dict)) else text_1
        left_texts = [as_text(item) for item in left_items]

        right_items = [text_2] if isinstance(text_2, (str, dict)) else text_2
        right_texts = [as_text(item) for item in right_items]

        num_left = len(left_texts)
        num_right = len(right_texts)
        if num_left > 1 and num_left != num_right:
            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
        if num_left == 0:
            raise ValueError("At least one text element must be given")
        if num_right == 0:
            raise ValueError("At least one text_pair element must be given")

        if num_left == 1:
            # 1 -> N: repeat the single query for every candidate.
            left_texts = left_texts * num_right

        pooling_params = PoolingParams()

        tokenization_kwargs: Dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs.update(truncation=True,
                                       max_length=truncate_prompt_tokens)

        parsed_prompts = []
        for query, candidate in zip(left_texts, right_texts):
            encoded = tokenizer(text=query,
                                text_pair=candidate,
                                **tokenization_kwargs)
            parsed_prompts.append(
                TokensPrompt(
                    prompt_token_ids=encoded["input_ids"],
                    token_type_ids=encoded.get("token_type_ids")))

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        engine_outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(engine_outputs,
                                                  PoolingRequestOutput)
958

959
960
961
962
963
964
    def start_profile(self) -> None:
        """Forward the start-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Forward the stop-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

965
966
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Turn legacy ``prompts`` / ``prompt_token_ids`` into prompt objects.

        Both arguments are normalized into batches with
        ``parse_and_batch_prompt``. When text prompts are given, each one
        becomes a ``TextPrompt``; otherwise each token-id list becomes a
        ``TokensPrompt``.

        Raises:
            ValueError: If neither argument is given, or if both are given
                with different batch sizes.
        """
        # skip_tokenizer_init is now checked in engine

        texts = None
        if prompts is not None:
            texts = [
                parsed["content"]
                for parsed in parse_and_batch_prompt(prompts)
            ]

        token_batches = None
        if prompt_token_ids is not None:
            token_batches = [
                parsed["content"]
                for parsed in parse_and_batch_prompt(prompt_token_ids)
            ]

        if texts is None and token_batches is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        if (texts is not None and token_batches is not None
                and len(texts) != len(token_batches)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        # Text prompts take precedence when both forms are supplied.
        converted: List[PromptType]
        if texts is not None:
            converted = [TextPrompt(prompt=text) for text in texts]
        else:
            converted = [
                TokensPrompt(prompt_token_ids=ids) for ids in token_batches
            ]
        return converted
1008
1009
1010

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate the batch arguments and enqueue one request per prompt.

        Args:
            prompts: A single prompt or a sequence of prompts.
            params: One params object shared by all prompts, or a sequence
                with one entry per prompt.
            lora_request: One LoRA request shared by all prompts, a sequence
                with one entry per prompt, or None.
            prompt_adapter_request: Prompt adapter request passed to every
                request, if any.
            guided_options: Deprecated guided decoding options; applied to
                every ``SamplingParams`` via ``_add_guided_params``.
            priority: Optional per-prompt priorities. Priority 0 is used
                when this is None or empty.

        Raises:
            ValueError: If a sequence of params or LoRA requests does not
                have the same length as the prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        # Use the same ``Sequence`` test as the per-request indexing below,
        # so tuples are validated and prepared exactly like lists.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1057

1058
    def _add_request(
nunjunj's avatar
nunjunj committed
1059
        self,
1060
        prompt: PromptType,
nunjunj's avatar
nunjunj committed
1061
        params: Union[SamplingParams, PoolingParams],
1062
        lora_request: Optional[LoRARequest] = None,
nunjunj's avatar
nunjunj committed
1063
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
1064
        priority: int = 0,
1065
1066
    ) -> None:
        request_id = str(next(self.request_counter))
1067
1068
        self.llm_engine.add_request(
            request_id,
1069
            prompt,
1070
1071
            params,
            lora_request=lora_request,
nunjunj's avatar
nunjunj committed
1072
            prompt_adapter_request=prompt_adapter_request,
1073
            priority=priority,
nunjunj's avatar
nunjunj committed
1074
        )
1075

1076
    def _add_guided_params(
1077
1078
1079
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and"
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)
1095
1096
        return params

1097
    def _run_engine(
1098
            self, *, use_tqdm: bool
1099
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
1100
1101
        # Initialize tqdm.
        if use_tqdm:
Zhuohan Li's avatar
Zhuohan Li committed
1102
            num_requests = self.llm_engine.get_num_unfinished_requests()
1103
1104
1105
1106
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
1107
1108
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
1109
            )
1110

Zhuohan Li's avatar
Zhuohan Li committed
1111
        # Run the engine.
1112
        outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
1113
1114
        total_in_toks = 0
        total_out_toks = 0
Zhuohan Li's avatar
Zhuohan Li committed
1115
1116
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
1117
            for output in step_outputs:
1118
                if output.finished:
1119
1120
                    outputs.append(output)
                    if use_tqdm:
1121
1122
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
1123
                            assert output.prompt_token_ids is not None
1124
1125
1126
                            total_in_toks += len(output.prompt_token_ids)
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(
1127
                                len(stp.token_ids) for stp in output.outputs)
nunjunj's avatar
nunjunj committed
1128
1129
                            out_spd = (total_out_toks /
                                       pbar.format_dict["elapsed"])
1130
1131
1132
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
1133
                        pbar.update(1)
1134

1135
1136
        if use_tqdm:
            pbar.close()
1137
1138
1139
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
1140
        return sorted(outputs, key=lambda x: int(x.request_id))