llm.py 25.2 KB
Newer Older
1
2
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
3
4

from tqdm import tqdm
Zhuohan Li's avatar
Zhuohan Li committed
5
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
9
10
11
from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
                         TextTokensPrompt, TokensPrompt,
                         parse_and_batch_prompt)
12
from vllm.logger import init_logger
13
from vllm.lora.request import LoRARequest
14
15
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
16
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
17
from vllm.sampling_params import SamplingParams
18
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
yhu422's avatar
yhu422 committed
19
from vllm.usage.usage_lib import UsageContext
20
from vllm.utils import Counter, deprecate_kwargs
21

22
23
logger = init_logger(__name__)

24
25

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
26
27
28
29
30
31
32
33
34
35
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
36
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
37
38
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
39
40
41
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
42
43
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
Woosuk Kwon's avatar
Woosuk Kwon committed
44
45
46
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
47
48
49
50
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
51
        quantization: The method used to quantize the model weights. Currently,
52
53
54
55
56
            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
57
58
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
59
60
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
61
62
63
64
65
66
67
68
69
70
71
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
72
73
74
75
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
76
77
78
79
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
80
81
82
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
83
84
            When a sequence has context length larger than this, we fall back
            to eager mode.
85
        disable_custom_all_reduce: See ParallelConfig
86
87
88
89
90
91
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)
    
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
Woosuk Kwon's avatar
Woosuk Kwon committed
92
    """
93

94
95
96
97
98
99
100
101
102
103
104
105
    # Class-wide switch for the legacy keyword API. It is read lazily (through
    # `lambda: LLM.DEPRECATE_LEGACY`) by the `deprecate_kwargs` decorators on
    # `generate` and `encode`, and toggled by `deprecate_legacy_api()`.
    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

106
107
108
    def __init__(
        self,
        model: str,
109
        tokenizer: Optional[str] = None,
110
        tokenizer_mode: str = "auto",
111
        skip_tokenizer_init: bool = False,
112
        trust_remote_code: bool = False,
113
        tensor_parallel_size: int = 1,
Woosuk Kwon's avatar
Woosuk Kwon committed
114
        dtype: str = "auto",
115
        quantization: Optional[str] = None,
116
        revision: Optional[str] = None,
117
        tokenizer_revision: Optional[str] = None,
118
119
120
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
121
        cpu_offload_gb: float = 0,
122
        enforce_eager: bool = False,
123
124
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
125
        disable_custom_all_reduce: bool = False,
126
127
128
129
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
130
131
132
133
134
        removed_vision_keys = ("image_token_id", "image_feature_size",
                               "image_input_shape", "image_input_type")
        if any(k in kwargs for k in removed_vision_keys):
            raise TypeError(
                "There is no need to pass vision-related arguments anymore.")
Zhuohan Li's avatar
Zhuohan Li committed
135
        engine_args = EngineArgs(
136
            model=model,
137
            tokenizer=tokenizer,
138
            tokenizer_mode=tokenizer_mode,
139
            skip_tokenizer_init=skip_tokenizer_init,
140
            trust_remote_code=trust_remote_code,
141
142
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
143
            quantization=quantization,
144
            revision=revision,
145
            tokenizer_revision=tokenizer_revision,
146
147
148
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
149
            cpu_offload_gb=cpu_offload_gb,
150
151
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
152
            max_seq_len_to_capture=max_seq_len_to_capture,
153
            disable_custom_all_reduce=disable_custom_all_reduce,
154
155
            **kwargs,
        )
yhu422's avatar
yhu422 committed
156
157
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
158
159
        self.request_counter = Counter()

160
    def get_tokenizer(
161
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
162
        return self.llm_engine.tokenizer.tokenizer
163

164
165
166
167
    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
168
169
170
171
172
173
174
175
        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        if tokenizer.__class__.__name__.startswith("Cached"):
            self.llm_engine.tokenizer.tokenizer = tokenizer
        else:
            self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
                tokenizer)
176

177
178
179
180
181
182
183
184
    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
185
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
186
187
188
189
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
190
191
    def generate(
        self,
192
        prompts: List[str],
193
194
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
195
        prompt_token_ids: Optional[List[List[int]]] = None,
196
        use_tqdm: bool = True,
197
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
198
199
200
201
202
203
204
205
206
207
208
209
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
210
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
211
212
213
214
215
216
217
218
219
220
221
222
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
223
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
224
225
226
227
228
229
230
231
232
233
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
234
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
235
236
237
238
239
240
241
242
243
244
245
246
    ) -> List[RequestOutput]:
        ...

    @overload
    def generate(
        self,
        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
247
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def generate(
        self,
        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
264
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
265
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
266
    ) -> List[RequestOutput]:
Woosuk Kwon's avatar
Woosuk Kwon committed
267
268
        """Generates the completions for the input prompts.

269
        This class automatically batches the given prompts, considering
Woosuk Kwon's avatar
Woosuk Kwon committed
270
271
272
273
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
274
            inputs: A list of inputs to generate completions for.
Woosuk Kwon's avatar
Woosuk Kwon committed
275
            sampling_params: The sampling parameters for text generation. If
276
277
278
279
                None, we use the default sampling parameters. 
                When it is a single value, it is applied to every prompt. 
                When it is a list, the list must have the same length as the 
                prompts and it is paired one by one with the prompt.
Woosuk Kwon's avatar
Woosuk Kwon committed
280
            use_tqdm: Whether to use tqdm to display the progress bar.
281
            lora_request: LoRA request to use for generation, if any.
282
283
            prompt_adapter_request: Prompt Adapter request to use for 
                generation, if any.
Woosuk Kwon's avatar
Woosuk Kwon committed
284
285

        Returns:
286
287
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.
288
289
290
291
292

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
293
        """
294
295
296
297
298
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for generation models "
                "(XForCausalLM).")

299
        if prompt_token_ids is not None:
300
301
302
303
304
305
306
307
308
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(
                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                prompts)

309
310
311
312
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

313
314
315
316
        self._validate_and_add_requests(
            inputs=inputs,
            params=sampling_params,
            lora_request=lora_request,
317
            prompt_adapter_request=prompt_adapter_request)
318

319
320
        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)
321

322
323
324
325
326
327
328
329
    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
330
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
331
332
    ) -> List[EmbeddingRequestOutput]:
        ...
333

334
    @overload  # LEGACY: multi (prompt + optional token ids)
335
336
    def encode(
        self,
337
        prompts: List[str],
338
        pooling_params: Optional[Union[PoolingParams,
339
                                       Sequence[PoolingParams]]] = None,
340
341
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
342
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
343
344
345
346
347
348
349
350
351
352
353
354
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
355
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
356
357
358
359
360
361
362
363
364
365
366
367
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
368
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
369
370
371
372
373
374
375
376
377
378
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
379
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
380
381
382
383
384
385
386
387
388
389
390
391
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload
    def encode(
        self,
        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
392
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
    ) -> List[EmbeddingRequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def encode(
        self,
        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
409
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
410
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
411
412
413
    ) -> List[EmbeddingRequestOutput]:
        """Generates the completions for the input prompts.

414
        This class automatically batches the given prompts, considering
415
416
417
418
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
419
420
421
            inputs: The inputs to the LLM. You may pass a sequence of inputs for
                batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
                for more details about the format of each input.
422
423
424
425
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
426
427
            prompt_adapter_request: Prompt Adapter request to use for 
                generation, if any.
428
429
430
431

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.
432
433
434
435
436

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
437
        """
438
439
440
441
442
        if not self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

443
        if prompt_token_ids is not None:
444
445
446
447
448
449
450
451
452
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(
                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                prompts)

453
454
455
456
        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

457
458
459
460
        self._validate_and_add_requests(
            inputs=inputs,
            params=pooling_params,
            lora_request=lora_request,
461
            prompt_adapter_request=prompt_adapter_request,
462
463
        )

464
465
        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
466

467
468
    # LEGACY
    def _convert_v1_inputs(
469
470
        self,
        prompts: Optional[Union[str, List[str]]],
471
472
473
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        # skip_tokenizer_init is now checked in engine
474

475
476
477
478
479
480
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]
481

482
        num_requests = None
483
484
        if prompts is not None:
            num_requests = len(prompts)
485
486
487
488
489
490
        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

491
            num_requests = len(prompt_token_ids)
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        inputs: List[PromptInputs] = []
        for i in range(num_requests):
            if prompts is not None:
                if prompt_token_ids is not None:
                    item = TextTokensPrompt(
                        prompt=prompts[i],
                        prompt_token_ids=prompt_token_ids[i])
                else:
                    item = TextPrompt(prompt=prompts[i])
            else:
                if prompt_token_ids is not None:
                    item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
                else:
                    raise AssertionError

            inputs.append(item)

        return inputs

    def _validate_and_add_requests(
        self,
        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
520
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
521
        prompt_adapter_request: Optional[PromptAdapterRequest],
522
523
524
525
526
527
    ) -> None:
        if isinstance(inputs, (str, dict)):
            # Convert a single prompt to a list.
            inputs = [inputs]

        num_requests = len(inputs)
528

529
530
        if isinstance(params, list) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
531
                             "must be the same.")
532
533
534
535
        if isinstance(lora_request,
                      list) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
536

Zhuohan Li's avatar
Zhuohan Li committed
537
        # Add requests to the engine.
538
539
540
541
        for i, request_inputs in enumerate(inputs):
            self._add_request(
                request_inputs,
                params[i] if isinstance(params, Sequence) else params,
542
543
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
544
                prompt_adapter_request=prompt_adapter_request)
545

546
    def _add_request(
547
548
549
550
551
552
            self,
            inputs: PromptInputs,
            params: Union[SamplingParams, PoolingParams],
            lora_request: Optional[Union[List[LoRARequest],
                                         LoRARequest]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
553
554
    ) -> None:
        request_id = str(next(self.request_counter))
555
556
557
558
559
560
        self.llm_engine.add_request(
            request_id,
            inputs,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)
561

562
    def _run_engine(
563
            self, *, use_tqdm: bool
564
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
565
566
        # Initialize tqdm.
        if use_tqdm:
Zhuohan Li's avatar
Zhuohan Li committed
567
            num_requests = self.llm_engine.get_num_unfinished_requests()
568
569
570
571
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
572
573
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
574
            )
Zhuohan Li's avatar
Zhuohan Li committed
575
        # Run the engine.
576
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
577
578
        total_in_toks = 0
        total_out_toks = 0
Zhuohan Li's avatar
Zhuohan Li committed
579
580
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
581
            for output in step_outputs:
582
                if output.finished:
583
584
                    outputs.append(output)
                    if use_tqdm:
585
586
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
587
588
589
                            total_in_toks += len(output.prompt_token_ids)
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(
590
                                len(stp.token_ids) for stp in output.outputs)
591
592
593
594
595
                            out_spd = total_out_toks / pbar.format_dict[
                                "elapsed"]
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
596
597
598
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
599
600
601
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
602
        return sorted(outputs, key=lambda x: int(x.request_id))