llm.py 26.8 KB
Newer Older
1
2
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
3
4

from tqdm import tqdm
Zhuohan Li's avatar
Zhuohan Li committed
5
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
9
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
10
                         parse_and_batch_prompt)
11
from vllm.logger import init_logger
12
from vllm.lora.request import LoRARequest
13
14
15
from vllm.model_executor.guided_decoding import (
    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
16
17
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
18
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
19
from vllm.sampling_params import SamplingParams
20
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
yhu422's avatar
yhu422 committed
21
from vllm.usage.usage_lib import UsageContext
22
from vllm.utils import Counter, deprecate_kwargs
23

24
25
# Module-level logger created through vLLM's ``init_logger`` helper.
logger = init_logger(__name__)

26
27

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: See ParallelConfig
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`). Unless passed explicitly,
            ``disable_log_stats`` defaults to True here.

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily enable deprecation warnings for the legacy API.

        The previous value of ``DEPRECATE_LEGACY`` is restored on exit, even
        when the ``with`` body raises. Without this, an exception would leave
        the flag stuck at True, and nested use would clear it too early.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = previous

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        # Stats logging is off by default for this entry point; an explicit
        # ``disable_log_stats`` keyword argument from the caller wins.
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True

        # These vision arguments were removed; reject them loudly rather than
        # forwarding unknown keys to EngineArgs.
        removed_vision_keys = ("image_token_id", "image_feature_size",
                               "image_input_shape", "image_input_type")
        if any(k in kwargs for k in removed_vision_keys):
            raise TypeError(
                "There is no need to pass vision-related arguments anymore.")

        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        # Monotonic source of request IDs; _run_engine sorts on these.
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        """Return the HuggingFace tokenizer used by the underlying engine."""
        return self.llm_engine.tokenizer.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        """Replace the engine's tokenizer, wrapping it in a cached tokenizer
        unless it already appears to be one."""
        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        if tokenizer.__class__.__name__.startswith("Cached"):
            self.llm_engine.tokenizer.tokenizer = tokenizer
        else:
            self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer(
                tokenizer)

    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload
    def generate(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def generate(
        self,
        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            inputs: The inputs to the LLM (passed positionally). You may pass
                a sequence of inputs for batch inference. See
                :class:`~vllm.inputs.PromptInputs` for more details about the
                format of each input.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a sequence, it must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                may hold at most one entry and is converted to a
                ``GuidedDecodingRequest``. The resulting logits processor is
                attached to copies of the sampling parameters, so the
                caller's objects are not modified.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model runs in embedding mode, if a guided
                options dict has more than one entry, or if per-prompt
                sequences have a length different from the prompts.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for generation models "
                "(XForCausalLM).")

        if prompt_token_ids is not None:
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)

    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload
    def encode(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def encode(
        self,
        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the embeddings for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            inputs: The inputs to the LLM (passed positionally). You may pass
                a sequence of inputs for batch inference. See
                :class:`~vllm.inputs.PromptInputs` for more details about the
                format of each input.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. When it is a sequence, it
                must have the same length as the prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running in embedding mode, or if
                per-prompt sequences have a length different from the prompts.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if not self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is not None:
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Convert legacy ``prompts``/``prompt_token_ids`` into prompt inputs.

        Either argument may be a single item or a batch. When both are given,
        their lengths must match, and the text prompts are the ones used to
        build each input.

        Raises:
            ValueError: If the lengths differ or neither argument is given.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        num_requests = None
        if prompts is not None:
            num_requests = len(prompts)

        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

            num_requests = len(prompt_token_ids)

        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        inputs: List[PromptInputs] = []
        for i in range(num_requests):
            if prompts is not None:
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
            else:
                # Unreachable: num_requests is set only when one is not None.
                raise AssertionError

            inputs.append(item)

        return inputs

    def _validate_and_add_requests(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
    ) -> None:
        """Validate per-prompt arguments and enqueue one request per input.

        ``params`` and ``lora_request`` may each be a single value (shared by
        every input) or a sequence paired one-to-one with ``inputs``.

        Raises:
            ValueError: If a per-prompt sequence has the wrong length.
        """
        if isinstance(inputs, (str, dict)):
            # Convert a single prompt to a list.
            inputs = [inputs]

        num_requests = len(inputs)

        # Validate with ``Sequence`` (not only ``list``), because the same
        # check decides below whether the value is indexed per request. A
        # tuple would otherwise skip validation yet still be indexed.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        if isinstance(params, Sequence):
            params = [
                self._add_guided_processor(param, guided_options)
                if isinstance(param, SamplingParams) else param
                for param in params
            ]
        elif isinstance(params, SamplingParams):
            params = self._add_guided_processor(params, guided_options)

        # Add requests to the engine.
        for i, request_inputs in enumerate(inputs):
            self._add_request(
                request_inputs,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request)

    def _add_request(
            self,
            inputs: PromptInputs,
            params: Union[SamplingParams, PoolingParams],
            lora_request: Optional[Union[List[LoRARequest],
                                         LoRARequest]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> None:
        """Submit a single request to the engine under a fresh request ID."""
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            inputs,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request)

    def _add_guided_processor(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Return ``params`` with a guided-decoding logits processor attached.

        If ``guided_options`` has no backend set, it is filled in from the
        engine's decoding config. The processor is appended to a shallow copy
        of ``params``, so a caller's ``SamplingParams`` does not accumulate
        processors across calls or across repeated list entries. ``params``
        is returned unchanged when there are no guided options or no
        processor is produced.
        """
        if not guided_options:
            return params

        if guided_options.guided_decoding_backend is None:
            decoding_config = self.llm_engine.get_decoding_config()
            guided_options.guided_decoding_backend = (
                decoding_config.guided_decoding_backend)
        guided_logits_processor = get_local_guided_decoding_logits_processor(  #noqa
            guided_options.guided_decoding_backend, guided_options,
            self.get_tokenizer())
        if not guided_logits_processor:
            return params

        import copy

        guided_params = copy.copy(params)
        # Build a new list so the caller's logits_processors is not mutated.
        guided_params.logits_processors = [
            *(params.logits_processors or []), guided_logits_processor
        ]
        return guided_params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until all queued requests finish.

        Returns the finished outputs sorted by numeric request ID, which
        matches submission order. If ``use_tqdm`` is set, a progress bar
        shows estimated input/output token throughput for generation
        outputs.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        elapsed = pbar.format_dict["elapsed"]
                        # A coarse clock can report 0.0 elapsed right after
                        # start; skip the speed update instead of dividing
                        # by zero.
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))