llm.py 27.3 KB
Newer Older
1
2
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
3
4

from tqdm import tqdm
Zhuohan Li's avatar
Zhuohan Li committed
5
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
9
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
10
                         parse_and_batch_prompt)
11
from vllm.logger import init_logger
12
from vllm.lora.request import LoRARequest
13
14
15
from vllm.model_executor.guided_decoding import (
    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
16
17
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
18
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
19
from vllm.sampling_params import SamplingParams
20
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
yhu422's avatar
yhu422 committed
21
from vllm.usage.usage_lib import UsageContext
22
from vllm.utils import Counter, deprecate_kwargs
23

24
25
# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)

26
27

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Inputs are then expected to carry valid
            prompt_token_ids and None for the prompt.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If None (the default), it becomes False for decoder-only models
            and True for encoder/decoder models, which do not currently
            support CUDA graphs.
        max_context_len_to_capture: Maximum context length covered by CUDA
            graphs. When a sequence has context length larger than this, we
            fall back to eager mode. (DEPRECATED. Use `max_seq_len_to_capture`
            instead.)
        max_seq_len_to_capture: Maximum sequence length covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: See ParallelConfig.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`) Unless passed explicitly, ``disable_log_stats``
            is set to True. The removed vision-related arguments
            (``image_token_id``, ``image_feature_size``, ``image_input_shape``,
            ``image_input_type``) are rejected with a TypeError.

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """
    # Class-wide switch. The ``is_deprecated`` callbacks given to
    # ``deprecate_kwargs`` on ``generate()`` and ``encode()`` return this
    # value, so it controls whether the legacy ``prompts`` /
    # ``prompt_token_ids`` keywords are treated as deprecated.
    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Context manager that sets ``DEPRECATE_LEGACY`` to True inside it.

        The value the flag had before entering is restored on exit. This
        happens in a ``finally`` clause, so an exception raised inside the
        ``with`` body cannot leave the flag stuck at True, and nested uses
        do not switch it off early for the outer block.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = previous

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        """Create the ``LLMEngine`` that backs this instance.

        Every named argument, together with ``**kwargs``, is forwarded
        unchanged to :class:`EngineArgs`; the engine is then built with
        ``UsageContext.LLM_CLASS``.

        Note: if enforce_eager is unset (enforce_eager is None) it defaults
        to False for decoder-only models and True for encoder/decoder models,
        since encoder/decoder models do not currently support CUDAGraph.

        Raises:
            TypeError: If one of the removed vision-related keyword arguments
                is passed.
        """
        # Stats logging is off by default; an explicit kwarg still wins.
        kwargs.setdefault("disable_log_stats", True)

        removed_vision_keys = ("image_token_id", "image_feature_size",
                               "image_input_shape", "image_input_type")
        if not kwargs.keys().isdisjoint(removed_vision_keys):
            raise TypeError(
                "There is no need to pass vision-related arguments anymore.")

        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        # Source of the numeric request IDs handed out by _add_request.
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        """Return the tokenizer stored on the engine's tokenizer wrapper.

        This is the object found at ``self.llm_engine.tokenizer.tokenizer``.
        """
        wrapper = self.llm_engine.tokenizer
        return wrapper.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        """Install ``tokenizer`` as the engine's tokenizer.

        A tokenizer whose class name does not start with "Cached" is first
        wrapped with ``get_cached_tokenizer``; otherwise it is stored as is.
        """
        # Cached tokenizer classes are created dynamically, so the class name
        # is the only marker available. A user-defined tokenizer class whose
        # name happens to start with 'Cached' will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        self.llm_engine.tokenizer.tokenizer = (
            tokenizer if looks_cached else get_cached_tokenizer(tokenizer))

    # Typing-only overloads of ``generate``; the decorated implementation
    # replaces them at runtime. Each overload also lists the
    # ``prompt_adapter_request`` and ``guided_options_request`` keywords the
    # implementation accepts, so type checkers do not reject those calls.

    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload
    def generate(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def generate(
        self,
        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            inputs: The inputs to the LLM. You may pass a sequence of inputs
                for batch inference. See :class:`~vllm.inputs.PromptInputs`
                for the format of each input. (Passed positionally; the
                parameter is named ``prompts`` for legacy compatibility.)
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: LEGACY. Token IDs for the prompts, either one
                list or one list per prompt. When given, it is combined with
                ``prompts`` (if any) by ``_convert_v1_inputs``; both must then
                describe the same number of prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. A list
                is paired one by one with the prompts.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, either as a
                ``GuidedDecodingRequest`` or as a dict of its fields. A dict
                may select at most one guide; ``guided_decoding_backend`` only
                picks the backend and is not counted as a guide.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model runs in embedding mode, or if a dict of
                guided options selects more than one guide.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).")

        if prompt_token_ids is not None:
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

        if isinstance(guided_options_request, dict):
            # These keys configure how a guide is enforced instead of
            # selecting a guide, so they must not count toward the
            # one-guide limit. ``guided_decoding_backend`` is a
            # GuidedDecodingRequest field (read in _add_guided_processor);
            # ``guided_whitespace_pattern`` is presumably a modifier as well.
            # TODO(review): confirm against LLMGuidedOptions.
            modifier_keys = {
                "guided_decoding_backend", "guided_whitespace_pattern"
            }
            num_guides = sum(1 for key in guided_options_request
                             if key not in modifier_keys)
            if num_guides > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)

    # Typing-only overloads of ``encode``; the decorated implementation
    # replaces them at runtime. Each overload also lists the
    # ``prompt_adapter_request`` keyword the implementation accepts, so type
    # checkers do not reject such calls.

    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload
    def encode(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def encode(
        self,
        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Computes embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            inputs: The inputs to the LLM. You may pass a sequence of inputs for
                batch inference. See :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input. (Passed
                positionally; the parameter is named ``prompts`` for legacy
                compatibility.)
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. A list must have the same
                length as the prompts and is paired one by one with them.
            prompt_token_ids: LEGACY. Token IDs for the prompts, either one
                list or one list per prompt. When given, it is combined with
                ``prompts`` (if any) by ``_convert_v1_inputs``; both must then
                describe the same number of prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any. A list is paired one
                by one with the prompts.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the loaded model is not running in embedding mode.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if not self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is not None:
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Turn the legacy ``prompts`` / ``prompt_token_ids`` pair into a list
        of prompt inputs, one entry per request.

        Each argument is first normalized into a batch with
        ``parse_and_batch_prompt``, keeping each item's ``"content"``.

        Returns:
            One ``TextPrompt`` per text prompt when ``prompts`` is given;
            otherwise one ``TokensPrompt`` per token-ID list.

        Raises:
            ValueError: If both batches are given with different lengths, or
                if neither is given.
        """
        # skip_tokenizer_init is now checked in engine
        texts = None
        if prompts is not None:
            texts = [item["content"] for item in parse_and_batch_prompt(prompts)]

        token_batches = None
        if prompt_token_ids is not None:
            token_batches = [
                item["content"]
                for item in parse_and_batch_prompt(prompt_token_ids)
            ]

        if texts is None and token_batches is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if (texts is not None and token_batches is not None
                and len(texts) != len(token_batches)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        # NOTE(review): when both are given, only the text prompts become
        # inputs; the token IDs are dropped after the length check above.
        # Confirm this is the intended legacy behavior.
        if texts is not None:
            return [TextPrompt(prompt=text) for text in texts]
        return [TokensPrompt(prompt_token_ids=ids) for ids in token_batches]

    def _validate_and_add_requests(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
    ) -> None:
        """Validate per-request arguments and queue every input on the engine.

        ``params`` and ``lora_request`` may each be one shared value or a
        sequence (list, tuple, ...) with exactly one entry per input. Any
        ``SamplingParams`` get a guided-decoding processor attached through
        ``_add_guided_processor`` before being submitted.

        Raises:
            ValueError: If a per-request sequence does not match the number
                of inputs.
        """
        if isinstance(inputs, (str, dict)):
            # Convert a single prompt to a list.
            inputs = [inputs]

        num_requests = len(inputs)

        # Use the same ``Sequence`` test as the indexing below. Checking only
        # ``list`` let tuples skip validation and guided-decoding setup.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        if isinstance(params, Sequence):
            params = [
                self._add_guided_processor(param, guided_options)
                if isinstance(param, SamplingParams) else param
                for param in params
            ]
        elif isinstance(params, SamplingParams):
            params = self._add_guided_processor(params, guided_options)

        # Add requests to the engine.
        for i, request_inputs in enumerate(inputs):
            self._add_request(
                request_inputs,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request)

    def _add_request(
            self,
            inputs: PromptInputs,
            params: Union[SamplingParams, PoolingParams],
            lora_request: Optional[Union[List[LoRARequest],
                                         LoRARequest]] = None,
            prompt_adapter_request: Optional[PromptAdapterRequest] = None
    ) -> None:
        """Submit a single request to the engine under a fresh request ID.

        The ID is the next value of ``self.request_counter`` converted to a
        string; ``_run_engine`` later sorts outputs by ``int(request_id)``.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            inputs,
            params,
            prompt_adapter_request=prompt_adapter_request,
            lora_request=lora_request,
        )

    def _add_guided_processor(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None
    ) -> SamplingParams:
        """Return ``params`` with a guided-decoding logits processor attached.

        The caller's ``SamplingParams`` is not modified. When a processor is
        added, it goes onto a shallow copy that gets a new
        ``logits_processors`` list. Before this change, the caller's list was
        appended to in place, so reusing one ``SamplingParams`` across several
        ``generate()`` calls stacked up one more guided processor per call.

        If ``guided_options`` has no backend set, the engine's decoding-config
        default is written into it (``guided_options`` itself is mutated).

        Args:
            params: Sampling parameters for one request, or shared by a batch.
            guided_options: Guided decoding request, or None for no guidance.

        Returns:
            ``params`` unchanged when no processor is produced, otherwise a
            copy of it with the processor appended.
        """
        import copy

        if not guided_options:
            return params

        if guided_options.guided_decoding_backend is None:
            decoding_config = self.llm_engine.get_decoding_config()
            guided_options.guided_decoding_backend = (
                decoding_config.guided_decoding_backend)
        guided_logits_processor = get_local_guided_decoding_logits_processor(  #noqa
            guided_options.guided_decoding_backend, guided_options,
            self.get_tokenizer())
        if not guided_logits_processor:
            return params

        params = copy.copy(params)
        params.logits_processors = [
            *(params.logits_processors or []), guided_logits_processor
        ]
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to show a tqdm progress bar. For
                ``RequestOutput`` results, the bar's postfix shows the
                estimated input/output token throughput.

        Returns:
            All finished outputs, sorted by integer request ID.
        """
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        elapsed = pbar.format_dict["elapsed"]
                        # tqdm's elapsed time can still be 0.0 when the first
                        # request finishes quickly (coarse clocks). Skip the
                        # speed update instead of dividing by zero.
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even if the engine raises mid-run.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))

    def _is_encoder_decoder_model(self):
        """Delegate to the engine's ``is_encoder_decoder_model()``."""
        engine = self.llm_engine
        return engine.is_encoder_decoder_model()

    def _is_embedding_model(self):
        """Delegate to the engine's ``is_embedding_model()``."""
        engine = self.llm_engine
        return engine.is_embedding_model()