llm.py 30 KB
Newer Older
1
2
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
3

4
from tqdm import tqdm
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
8
9
10
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
                                         apply_chat_template,
                                         parse_chat_messages)
11
12
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
13
from vllm.logger import init_logger
14
from vllm.lora.request import LoRARequest
15
16
17
from vllm.model_executor.guided_decoding import (
    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
18
19
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
20
from vllm.prompt_adapter.request import PromptAdapterRequest
Woosuk Kwon's avatar
Woosuk Kwon committed
21
from vllm.sampling_params import SamplingParams
22
23
24
from vllm.transformers_utils.tokenizer import (AnyTokenizer,
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
25
from vllm.usage.usage_lib import UsageContext
26
from vllm.utils import Counter, deprecate_kwargs
27

28
29
# Module-level logger for this file, created through vLLM's init_logger.
logger = init_logger(__name__)

30
31

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If None (the default), it resolves to False for decoder-only
            models and to True for encoder/decoder models, which do not
            currently support CUDA graphs.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: See ParallelConfig.
        disable_async_output_proc: Disable asynchronous output processing in
            the engine (forwarded to :class:`~vllm.EngineArgs`); this may
            lower performance.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`) ``disable_log_stats`` defaults to True when
            it is not given. The removed vision arguments
            (``image_token_id``, ``image_feature_size``,
            ``image_input_shape``, ``image_input_type``) are rejected with a
            ``TypeError``.

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """

    # When True, the ``deprecate_kwargs`` decorators on ``generate`` and
    # ``encode`` treat the legacy ``prompts``/``prompt_token_ids`` keywords
    # as deprecated. Normally toggled via ``deprecate_legacy_api()``.
    DEPRECATE_LEGACY: ClassVar[bool] = False

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily mark the legacy generate/encode keyword API deprecated.

        While the context is active, :attr:`DEPRECATE_LEGACY` is True, so the
        ``deprecate_kwargs`` decorators on :meth:`generate` and
        :meth:`encode` treat the ``prompts``/``prompt_token_ids`` keywords as
        deprecated. The previous value is restored on exit, even when the
        body of the ``with`` statement raises.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore (rather than hard-reset to False) so nested use or a
            # flag set explicitly by the caller is not clobbered, and so an
            # exception inside the block cannot leave the flag stuck on.
            cls.DEPRECATE_LEGACY = previous

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        **kwargs,
    ) -> None:
        """Build the underlying :class:`LLMEngine` for offline inference.

        Every named argument, together with ``**kwargs``, is forwarded
        unchanged to :class:`EngineArgs`. Before that, ``disable_log_stats``
        is defaulted to True when absent from ``kwargs``. The removed vision
        keywords (``image_token_id``, ``image_feature_size``,
        ``image_input_shape``, ``image_input_type``) are rejected with a
        :class:`TypeError`.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False for decoder-only models and True
        for encoder/decoder models, since encoder/decoder models
        do not currently support CUDAGraph.
        """
        # Offline use: stats logging stays off unless explicitly requested.
        kwargs.setdefault("disable_log_stats", True)

        legacy_vision_keys = ("image_token_id", "image_feature_size",
                              "image_input_shape", "image_input_type")
        if any(key in kwargs for key in legacy_vision_keys):
            raise TypeError(
                "There is no need to pass vision-related arguments anymore.")

        self.llm_engine = LLMEngine.from_engine_args(
            EngineArgs(
                model=model,
                tokenizer=tokenizer,
                tokenizer_mode=tokenizer_mode,
                skip_tokenizer_init=skip_tokenizer_init,
                trust_remote_code=trust_remote_code,
                tensor_parallel_size=tensor_parallel_size,
                dtype=dtype,
                quantization=quantization,
                revision=revision,
                tokenizer_revision=tokenizer_revision,
                seed=seed,
                gpu_memory_utilization=gpu_memory_utilization,
                swap_space=swap_space,
                cpu_offload_gb=cpu_offload_gb,
                enforce_eager=enforce_eager,
                max_context_len_to_capture=max_context_len_to_capture,
                max_seq_len_to_capture=max_seq_len_to_capture,
                disable_custom_all_reduce=disable_custom_all_reduce,
                disable_async_output_proc=disable_async_output_proc,
                **kwargs,
            ),
            usage_context=UsageContext.LLM_CLASS,
        )
        # Source of the request ids handed to the engine by _add_request.
        self.request_counter = Counter()

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's TokenizerGroup."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` as the engine's tokenizer.

        A tokenizer that does not already look cached is first wrapped with
        :func:`get_cached_tokenizer`.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # Cached tokenizer classes are generated dynamically, so there is no
        # class to isinstance-check against; the class-name prefix is used
        # instead. A user-defined tokenizer whose class name happens to start
        # with "Cached" will therefore be treated as already cached.
        already_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if already_cached else
                           get_cached_tokenizer(tokenizer))

    # Typing-only overloads; the decorated implementation below is what runs.
    # Each overload also declares the implementation's trailing keyword
    # arguments (prompt_adapter_request, guided_options_request) so that type
    # checkers accept calls which pass them.

    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload
    def generate(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs(
        "prompts",
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'inputs' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The inputs to generate completions for, either one
                prompt or a sequence of them (see
                :class:`~vllm.inputs.PromptInputs`). In the non-legacy API
                this is the positional ``inputs`` argument.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: LEGACY. Token ids for the prompts. When given,
                ``prompts`` and ``prompt_token_ids`` are converted together
                into prompt inputs.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options attached to every
                SamplingParams in this call, given as a
                ``GuidedDecodingRequest`` or an ``LLMGuidedOptions`` dict.
                A dict may select at most one guided decoding mode;
                ``guided_decoding_backend`` and ``guided_whitespace_pattern``
                are settings, not modes, and may be combined with it.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the loaded model is an embedding model, if a
                guided options dict selects more than one decoding mode, or
                if a per-prompt list has a different length than the prompts.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).")

        if prompt_token_ids is not None:
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

        if isinstance(guided_options_request, dict):
            # Only the decoding *modes* are mutually exclusive. The backend
            # and whitespace pattern configure whichever mode is chosen, and
            # a key explicitly set to None selects nothing.
            guided_modes = [
                key for key, value in guided_options_request.items()
                if value is not None and key not in (
                    "guided_decoding_backend", "guided_whitespace_pattern")
            ]
            if len(guided_modes) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)

    def chat(
        self,
        messages: List[ChatCompletionMessageParam],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
    ) -> List[RequestOutput]:
        """
        Generates a response for a chat conversation.

        The messages are parsed with :func:`parse_chat_messages` and rendered
        into a prompt by :func:`apply_chat_template` using the model's
        tokenizer. The result is then passed to :meth:`generate`.

        Args:
            messages: One conversation, given as a list of messages. Each
                message is a dictionary with 'role' and 'content' keys.
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            add_generation_prompt: Forwarded to :func:`apply_chat_template`.
                If True, the template is asked to append its generation
                prompt after the conversation, rather than adding anything
                to each individual message.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            response(s) for the conversation.
        """
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()

        # NOTE(review): the second value returned by parse_chat_messages is
        # discarded. If it carries multimodal data, that data never reaches
        # generate(); confirm this is intended.
        conversations, _ = parse_chat_messages(messages, model_config,
                                               tokenizer)

        prompts = apply_chat_template(
            tokenizer,
            conversations,
            chat_template=chat_template,
            add_generation_prompt=add_generation_prompt)

        return self.generate(
            prompts,
            sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    # Typing-only overloads; the decorated implementation below is what runs.
    # Each overload also declares the implementation's trailing
    # prompt_adapter_request keyword so type checkers accept calls passing it.

    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload
    def encode(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @deprecate_kwargs(
        "prompts",
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'inputs' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The inputs to the LLM (the positional ``inputs``
                argument in the non-legacy API). You may pass a sequence of
                inputs for batch inference. See
                :class:`~vllm.inputs.PromptInputs` for more details about the
                format of each input.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. When it is a sequence, it
                must have the same length as the prompts.
            prompt_token_ids: LEGACY. Token ids for the prompts. When given,
                ``prompts`` and ``prompt_token_ids`` are converted together
                into prompt inputs.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for encoding, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                encoding, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the loaded model is not an embedding model, or if
                a per-prompt list has a different length than the prompts.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if not self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is not None:
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Translate the legacy ``prompts``/``prompt_token_ids`` pair into a
        list of prompt inputs.

        Each argument may be a single item or a batch; both are normalised to
        batches with :func:`parse_and_batch_prompt`. When both are given they
        must have the same length, and the text prompts win: the token ids
        are then used only for that length check.

        Raises:
            ValueError: If neither argument is given, or if both are given
                with different lengths.
        """
        # skip_tokenizer_init is now checked in engine
        texts = (None if prompts is None else
                 [p["content"] for p in parse_and_batch_prompt(prompts)])
        token_lists = (None if prompt_token_ids is None else [
            p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
        ])

        if texts is None and token_lists is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if (texts is not None and token_lists is not None
                and len(texts) != len(token_lists)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        # NOTE(review): when both are given, the token ids are dropped here
        # and only the text prompt reaches the engine.
        if texts is not None:
            return [TextPrompt(prompt=text) for text in texts]
        return [TokensPrompt(prompt_token_ids=ids) for ids in token_lists]

    def _validate_and_add_requests(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
    ) -> None:
        """Validate per-request arguments and enqueue one request per input.

        ``params`` and ``lora_request`` may each be a single value shared by
        every input, or a sequence paired one-to-one with ``inputs``. If
        ``guided_options`` is given, a guided-decoding processor is attached
        to every SamplingParams (PoolingParams are passed through untouched).

        Raises:
            ValueError: If ``params`` or ``lora_request`` is a sequence whose
                length differs from the number of inputs.
        """
        if isinstance(inputs, (str, dict)):
            # Convert a single prompt to a list.
            inputs = [inputs]

        num_requests = len(inputs)

        # Validate with the same ``Sequence`` test that the indexing below
        # uses. Checking only ``list`` let a tuple of the wrong length through,
        # ending in an IndexError or in silently unused entries.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        if isinstance(params, Sequence):
            params = [
                self._add_guided_processor(param, guided_options)
                if isinstance(param, SamplingParams) else param
                for param in params
            ]
        elif isinstance(params, SamplingParams):
            params = self._add_guided_processor(params, guided_options)

        # Add requests to the engine.
        for i, request_inputs in enumerate(inputs):
            self._add_request(
                request_inputs,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

    def _add_request(
        self,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Submit one request to the engine, using the next value of
        ``self.request_counter`` (as a string) for its request id."""
        # _run_engine later orders results by int(request_id), so the id
        # must stay the plain stringified counter value.
        next_id = str(next(self.request_counter))
        engine = self.llm_engine
        engine.add_request(next_id,
                           inputs,
                           params,
                           lora_request=lora_request,
                           prompt_adapter_request=prompt_adapter_request)

    def _add_guided_processor(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Return ``params`` with a guided-decoding logits processor added.

        If ``guided_options`` is falsy, ``params`` is returned unchanged.
        Otherwise a missing backend on ``guided_options`` is filled in from
        the engine's decoding config, and a logits processor is built. If one
        is produced, it is appended to the logits processors of a *clone* of
        ``params``, and that clone is returned.

        The caller's SamplingParams object is never modified. Previously,
        reusing one object across calls, or passing ``[sp] * n``, stacked
        several guided processors onto every request.
        """
        if not guided_options:
            return params

        if guided_options.guided_decoding_backend is None:
            decoding_config = self.llm_engine.get_decoding_config()
            guided_options.guided_decoding_backend = (
                decoding_config.guided_decoding_backend)
        guided_logits_processor = get_local_guided_decoding_logits_processor(  #noqa
            guided_options.guided_decoding_backend, guided_options,
            self.get_tokenizer())
        if not guided_logits_processor:
            return params

        # Copy first: SamplingParams.clone() deep-copies everything except
        # the logits-processor objects themselves.
        params = params.clone()
        params.logits_processors = [
            *(params.logits_processors or []), guided_logits_processor
        ]
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until every queued request has finished.

        The engine is switched to return only finished outputs for the
        duration of the loop. That switch is undone, and the progress bar
        closed, even if a step raises.

        Args:
            use_tqdm: Whether to show a tqdm progress bar. For
                ``RequestOutput`` results, its postfix also reports the
                estimated input/output token throughput.

        Returns:
            The finished outputs, sorted by numeric request id.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # In the loop below, only finished outputs are used
        self.llm_engine.step_return_finished_only = True

        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            # Run the engine.
            while self.llm_engine.has_unfinished_requests():
                for output in self.llm_engine.step():
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        elapsed = pbar.format_dict["elapsed"]
                        # A coarse clock can report 0s elapsed; skip the
                        # speed estimate rather than divide by zero.
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Restore original behavior even if a step raised, so the engine
            # is not left in finished-only mode and the bar is not leaked.
            self.llm_engine.step_return_finished_only = False
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))

    def _is_encoder_decoder_model(self):
        """Report whether the loaded model is encoder/decoder (asks the
        engine)."""
        engine = self.llm_engine
        return engine.is_encoder_decoder_model()

    def _is_embedding_model(self):
        """Report whether the loaded model is an embedding model (asks the
        engine)."""
        engine = self.llm_engine
        return engine.is_embedding_model()