"""Offline inference entry point: the :class:`LLM` class."""
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
                         parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import Counter, deprecate_kwargs

# Module-level logger for this file.
logger = init_logger(__name__)

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If True, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: Forwarded to the engine arguments; see
            ParallelConfig.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily set :attr:`DEPRECATE_LEGACY` to True.

        While the flag is set, the ``deprecate_kwargs`` decorators on
        :meth:`generate` and :meth:`encode` treat the legacy ``prompts`` /
        ``prompt_token_ids`` keyword arguments as deprecated.

        The previous value of the flag is restored on exit. This also happens
        when the ``with`` body raises, so the flag cannot leak into unrelated
        code, and nested uses keep the flag set until the outermost exits.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = previous

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        """Build the underlying :class:`LLMEngine` from the given options.

        The named arguments are documented on the class. Any extra ``kwargs``
        are forwarded unchanged to :class:`EngineArgs`.

        Raises:
            TypeError: If one of the removed vision options
                (``image_token_id``, ``image_feature_size``,
                ``image_input_shape``, ``image_input_type``) is passed.
        """
        # Stats logging is off unless the caller explicitly sets it.
        kwargs.setdefault("disable_log_stats", True)

        # These vision-related options were removed; reject them explicitly.
        for removed_key in ("image_token_id", "image_feature_size",
                            "image_input_shape", "image_input_type"):
            if removed_key in kwargs:
                raise TypeError(
                    "There is no need to pass vision-related arguments anymore.")

        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        # Source of request ids; _run_engine sorts outputs by int(request_id).
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        """Return the HuggingFace tokenizer the engine is currently using.

        This is the inner ``tokenizer`` attribute of the engine-side tokenizer
        object (``self.llm_engine.tokenizer``).
        """
        engine_tokenizer = self.llm_engine.tokenizer
        return engine_tokenizer.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        """Replace the tokenizer used by the engine.

        A tokenizer that does not already look like a cached wrapper is first
        wrapped with :func:`get_cached_tokenizer`. The result is stored on
        ``self.llm_engine.tokenizer.tokenizer``.
        """
        # Cached tokenizer classes are created dynamically, so the only
        # available check is the class-name prefix. A user-defined tokenizer
        # whose class name starts with 'Cached' will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        self.llm_engine.tokenizer.tokenizer = (
            tokenizer if looks_cached else get_cached_tokenizer(tokenizer))

    # Typing-only overloads for ``generate``. The ``LEGACY`` ones describe the
    # old ``prompts``/``prompt_token_ids`` calling conventions. Every overload
    # accepts ``prompt_adapter_request`` because the implementation does.

    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        ...

    @overload
    def generate(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def generate(
        self,
        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            inputs: The inputs to the LLM. You may pass a sequence of inputs for
                batch inference. See :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input. (This is the
                first positional parameter, named ``prompts`` at runtime.)
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: LEGACY. Token ids of the prompts, for one prompt
                or a batch.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. A list
                is paired one by one with the prompts.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is an embedding model; if neither
                prompts nor prompt_token_ids is given; or if the number of
                prompts does not match the length of prompt_token_ids, of a
                list of sampling_params, or of a list of lora_request.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for generation models "
                "(XForCausalLM).")

        if prompt_token_ids is not None or prompts is None:
            # Legacy path. Also taken when nothing was passed at all, so the
            # converter reports the missing input as a ValueError instead of
            # an opaque len(None) TypeError later on.
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)

    # Typing-only overloads for ``encode``. The ``LEGACY`` ones describe the
    # old ``prompts``/``prompt_token_ids`` calling conventions. Every overload
    # accepts ``prompt_adapter_request`` because the implementation does.

    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload
    def encode(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def encode(
        self,
        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            inputs: The inputs to the LLM. You may pass a sequence of inputs for
                batch inference. See :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input. (This is the
                first positional parameter, named ``prompts`` at runtime.)
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. A list is paired one by
                one with the prompts.
            prompt_token_ids: LEGACY. Token ids of the prompts, for one prompt
                or a batch.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. A list
                is paired one by one with the prompts.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the model is not an embedding model; if neither
                prompts nor prompt_token_ids is given; or if the number of
                prompts does not match the length of prompt_token_ids, of a
                list of pooling_params, or of a list of lora_request.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if not self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is not None or prompts is None:
            # Legacy path. Also taken when nothing was passed at all, so the
            # converter reports the missing input as a ValueError instead of
            # an opaque len(None) TypeError later on.
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Translate legacy ``prompts`` / ``prompt_token_ids`` arguments into a
        list of prompt inputs, one per request.

        Each argument may be a single item or a batch. Both are normalized to
        batches with :func:`parse_and_batch_prompt`. When text prompts are
        present they produce :class:`TextPrompt` items. Otherwise the token
        ids produce :class:`TokensPrompt` items.

        NOTE(review): when both arguments are supplied, only their lengths are
        cross-checked and the token ids are not carried into the returned
        items. Confirm this is the intended legacy behavior.

        Raises:
            ValueError: If both arguments are None, or if both are given with
                different batch lengths.
        """
        # skip_tokenizer_init is now checked in engine
        texts = None
        if prompts is not None:
            texts = [p["content"] for p in parse_and_batch_prompt(prompts)]

        token_batches = None
        if prompt_token_ids is not None:
            token_batches = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if texts is None and token_batches is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        if (texts is not None and token_batches is not None
                and len(texts) != len(token_batches)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        if texts is not None:
            return [TextPrompt(prompt=text) for text in texts]

        assert token_batches is not None  # guaranteed by the check above
        return [TokensPrompt(prompt_token_ids=ids) for ids in token_batches]

    def _validate_and_add_requests(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        """Validate per-request arguments and enqueue one request per input.

        Args:
            inputs: A single prompt (a ``str`` or a prompt dict) or a sequence
                of them.
            params: Sampling/pooling parameters. Either one object shared by
                every request, or a sequence paired one by one with ``inputs``.
            lora_request: ``None``, one LoRA request shared by every request,
                or a sequence paired one by one with ``inputs``.
            prompt_adapter_request: Shared by every request.

        Raises:
            ValueError: If a sequence-valued ``params`` or ``lora_request``
                does not have exactly one entry per input.
        """
        if isinstance(inputs, (str, dict)):
            # Convert a single prompt to a list.
            inputs = [inputs]

        num_requests = len(inputs)

        # Length checks use the same ``Sequence`` test as the per-request
        # dispatch below. Checking only for ``list`` would let a wrong-length
        # tuple through, leading to IndexError or silently unused entries.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if (isinstance(lora_request, Sequence)
                and len(lora_request) != num_requests):
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        # Add requests to the engine.
        for i, request_inputs in enumerate(inputs):
            self._add_request(
                request_inputs,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request)

    def _add_request(
            self,
            inputs: PromptInputs,
            params: Union[SamplingParams, PoolingParams],
            lora_request: Optional[Union[List[LoRARequest],
                                         LoRARequest]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> None:
        """Submit a single request to the engine.

        The request id is the string form of the next value drawn from
        ``self.request_counter``. ``_run_engine`` later sorts outputs by
        ``int(request_id)``.
        """
        self.llm_engine.add_request(
            str(next(self.request_counter)),
            inputs,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until it has no unfinished requests.

        Args:
            use_tqdm: Show a progress bar over the requests that are pending
                when this is called. For ``RequestOutput`` results, the bar's
                postfix also shows the estimated input/output token throughput.

        Returns:
            Every finished output, sorted by ``int(request_id)``.
        """
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # Read the elapsed time once and skip the speed
                        # update if no time has been measured yet, which
                        # would otherwise divide by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the progress bar even if the engine raises mid-run.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))