llm.py 31 KB
Newer Older
1
from contextlib import contextmanager
2
3
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
                    overload)
4

5
from tqdm import tqdm
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
9
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
10
11
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
nunjunj's avatar
nunjunj committed
12
                                         parse_chat_messages)
13
14
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.inputs.parse import parse_and_batch_prompt
15
from vllm.logger import init_logger
16
from vllm.lora.request import LoRARequest
17
18
19
from vllm.model_executor.guided_decoding import (
    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
20
21
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
22
from vllm.prompt_adapter.request import PromptAdapterRequest
23
from vllm.sampling_params import RequestOutputKind, SamplingParams
24
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
25
26
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
27
from vllm.usage.usage_lib import UsageContext
28
from vllm.utils import Counter, deprecate_kwargs, is_list_of
29

30
31
# Module-level logger obtained from vLLM's ``init_logger`` helper.
logger = init_logger(__name__)

32
33

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
34
35
36
37
38
39
40
41
42
43
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
44
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
45
46
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
47
48
49
        skip_tokenizer_init: If True, skip initialization of the tokenizer and
            detokenizer. Inputs must then supply valid ``prompt_token_ids``
            and no text ``prompt``.
50
51
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
Woosuk Kwon's avatar
Woosuk Kwon committed
52
53
54
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
55
56
57
58
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
59
        quantization: The method used to quantize the model weights. Currently,
60
            we support "awq", "gptq", and "fp8" (experimental).
61
62
63
64
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
65
66
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
67
68
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
69
70
71
72
73
74
75
76
77
78
79
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
80
81
82
83
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
84
85
86
87
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
88
89
90
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
91
            When a sequence has context length larger than this, we fall back
92
93
94
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
95
        disable_custom_all_reduce: See ParallelConfig
        disable_async_output_proc: Forwarded unchanged to
            :class:`~vllm.EngineArgs`.
96
97
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)
nunjunj's avatar
nunjunj committed
98

99
100
101
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
Woosuk Kwon's avatar
Woosuk Kwon committed
102
    """
103

104
105
106
107
108
109
110
111
112
113
114
115
    # Class-wide switch read by the ``deprecate_kwargs`` decorators on
    # ``generate`` and ``encode`` (through their ``is_deprecated`` callback);
    # switched on temporarily by the ``deprecate_legacy_api`` context manager.
    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

116
117
118
    def __init__(
        self,
        model: str,
119
        tokenizer: Optional[str] = None,
120
        tokenizer_mode: str = "auto",
121
        skip_tokenizer_init: bool = False,
122
        trust_remote_code: bool = False,
123
        tensor_parallel_size: int = 1,
Woosuk Kwon's avatar
Woosuk Kwon committed
124
        dtype: str = "auto",
125
        quantization: Optional[str] = None,
126
        revision: Optional[str] = None,
127
        tokenizer_revision: Optional[str] = None,
128
129
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
130
        swap_space: float = 4,
131
        cpu_offload_gb: float = 0,
132
        enforce_eager: Optional[bool] = None,
133
134
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
135
        disable_custom_all_reduce: bool = False,
136
        disable_async_output_proc: bool = False,
137
138
        **kwargs,
    ) -> None:
139
140
141
142
        '''
        LLM constructor.

        Note: if enforce_eager is unset (enforce_eager is None)
143
        it defaults to False.
144
145
        '''

146
147
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
nunjunj's avatar
nunjunj committed
148
149
150
151
152
153
        removed_vision_keys = (
            "image_token_id",
            "image_feature_size",
            "image_input_shape",
            "image_input_type",
        )
154
155
156
        if any(k in kwargs for k in removed_vision_keys):
            raise TypeError(
                "There is no need to pass vision-related arguments anymore.")
Zhuohan Li's avatar
Zhuohan Li committed
157
        engine_args = EngineArgs(
158
            model=model,
159
            tokenizer=tokenizer,
160
            tokenizer_mode=tokenizer_mode,
161
            skip_tokenizer_init=skip_tokenizer_init,
162
            trust_remote_code=trust_remote_code,
163
164
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
165
            quantization=quantization,
166
            revision=revision,
167
            tokenizer_revision=tokenizer_revision,
168
169
170
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
171
            cpu_offload_gb=cpu_offload_gb,
172
173
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
174
            max_seq_len_to_capture=max_seq_len_to_capture,
175
            disable_custom_all_reduce=disable_custom_all_reduce,
176
            disable_async_output_proc=disable_async_output_proc,
177
178
            **kwargs,
        )
yhu422's avatar
yhu422 committed
179
180
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
181
182
        self.request_counter = Counter()

183
184
185
186
187
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer that does not already look cached is wrapped with
        :func:`get_cached_tokenizer` before being stored.
        """
        target_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # Cached tokenizer classes are created dynamically, so there is no
        # class to isinstance() against; we have no choice but to compare the
        # class name. A user-defined tokenizer class whose name starts with
        # "Cached" will therefore be mistaken for an already-cached one.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        target_group.tokenizer = (tokenizer if looks_cached else
                                  get_cached_tokenizer(tokenizer))
196

197
198
199
200
201
202
203
204
    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        # Static-typing stub: one text prompt, optionally with its token IDs.
        # The last two keywords mirror the implementation so that adapter and
        # guided-decoding options also type-check in this call form.
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        # Static-typing stub: a batch of text prompts, optionally with their
        # token IDs. The last two keywords mirror the implementation.
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        # Static-typing stub: one prompt given as keyword token IDs, with an
        # optional text prompt. The last two keywords mirror the
        # implementation.
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        # Static-typing stub: a batch given as keyword token IDs, with
        # optional text prompts. The last two keywords mirror the
        # implementation.
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        # Static-typing stub: token IDs passed positionally after two
        # explicit ``None`` arguments. The last two parameters mirror the
        # implementation's order.
        ...

    @overload
    def generate(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        # Static-typing stub for the current API: one prompt input or a
        # sequence of them, with every option keyword-only. Adapter and
        # guided-decoding options mirror the implementation.
        ...

nunjunj's avatar
nunjunj committed
271
272
273
274
275
276
    @deprecate_kwargs(
        "prompts",
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'inputs' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            inputs: The inputs to the LLM, passed positionally (the parameter
                is still named ``prompts`` for the legacy API). You may pass a
                sequence of inputs for batch inference. See
                :class:`~vllm.inputs.PromptInputs` for more details about the
                format of each input.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: LEGACY. Token IDs of the prompt(s). When given,
                they are combined with ``prompts`` (if any) into prompt
                inputs; both must then describe the same number of prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. When it
                is a list, it is paired one by one with the prompts.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options applied to every
                request, if any. A dict (``LLMGuidedOptions``) may specify at
                most one guided decoding mode and is converted into a
                :class:`GuidedDecodingRequest`.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the loaded model is an embedding model, or if a
                dict ``guided_options_request`` specifies more than one
                guided decoding mode.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).")

        if prompt_token_ids is not None:
            # Legacy call form: build prompt inputs from the two arguments.
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)
351

nunjunj's avatar
nunjunj committed
352
353
354
355
356
357
358
359
    def chat(
        self,
        messages: List[ChatCompletionMessageParam],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is rendered into a single prompt using the
        tokenizer's chat template, and the :meth:`generate` method is then
        called to generate the response.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation represented as a list of messages.
                Each message is a dictionary with 'role' and 'content' keys.
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. Because one
                conversation produces exactly one prompt, a list must contain
                exactly one element.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            add_generation_prompt: Forwarded as ``add_generation_prompt`` to
                the chat template. For HF templates, True appends the tokens
                that open the assistant's reply once, at the end of the
                conversation (not after every message).
            tools: Optional tool definitions, forwarded to the chat template.

        Returns:
            A list of ``RequestOutput`` objects; the single conversation
            yields a single request, hence a single output.
        """
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()

        conversation, mm_data = parse_chat_messages(messages, model_config,
                                                    tokenizer)

        # A chat template may render either to text or directly to token IDs.
        prompt: Union[str, List[int]]
        if isinstance(tokenizer, MistralTokenizer):
            # The Mistral path is given the raw ``messages``, not the parsed
            # ``conversation``.
            prompt = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                tools=tools,
            )
        else:
            prompt = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                tools=tools,
            )

        inputs: PromptInputs
        if is_list_of(prompt, int):
            inputs = TokensPrompt(prompt_token_ids=prompt)
        else:
            inputs = TextPrompt(prompt=prompt)

        # Attach any multi-modal payload extracted from the messages.
        if mm_data is not None:
            inputs["multi_modal_data"] = mm_data

        return self.generate(
            inputs,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

433
434
435
436
437
438
439
440
    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Static-typing stub: one text prompt, optionally with its token IDs.
        # ``prompt_adapter_request`` mirrors the implementation so it also
        # type-checks in this call form.
        ...
444

445
    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Static-typing stub: a batch of text prompts, optionally with their
        # token IDs. ``prompt_adapter_request`` mirrors the implementation.
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Static-typing stub: one prompt given as keyword token IDs, with an
        # optional text prompt. ``prompt_adapter_request`` mirrors the
        # implementation.
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Static-typing stub: a batch given as keyword token IDs, with
        # optional text prompts. ``prompt_adapter_request`` mirrors the
        # implementation.
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Static-typing stub: token IDs passed positionally after two
        # explicit ``None`` arguments. ``prompt_adapter_request`` mirrors the
        # implementation's parameter order.
        ...

    @overload
    def encode(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Static-typing stub for the current API: one prompt input or a
        # sequence of them, with every option keyword-only.
        # ``prompt_adapter_request`` mirrors the implementation.
        ...

nunjunj's avatar
nunjunj committed
507
508
509
510
511
512
    @deprecate_kwargs(
        "prompts",
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'inputs' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates embeddings for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            inputs: The inputs to the LLM, passed positionally (the parameter
                is still named ``prompts`` for the legacy API). You may pass a
                sequence of inputs for batch inference. See
                :class:`~vllm.inputs.PromptInputs` for more details about the
                format of each input.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. When it is a list, it must
                have the same length as the inputs.
            prompt_token_ids: LEGACY. Token IDs of the prompt(s). When given,
                they are combined with ``prompts`` (if any) into prompt
                inputs; both must then describe the same number of prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the loaded model is not an embedding model.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if not self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is not None:
            # Legacy call form: build prompt inputs from the two arguments.
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
576

577
578
579
580
581
582
    def start_profile(self) -> None:
        self.llm_engine.start_profile()

    def stop_profile(self) -> None:
        self.llm_engine.stop_profile()

583
584
    # LEGACY
    def _convert_v1_inputs(
585
586
        self,
        prompts: Optional[Union[str, List[str]]],
587
588
589
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        # skip_tokenizer_init is now checked in engine
590

591
592
593
594
595
596
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]
597

598
        num_requests = None
599
600
        if prompts is not None:
            num_requests = len(prompts)
601
602
603
604
605
606
        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

607
            num_requests = len(prompt_token_ids)
608
609
610
611
612
613
        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        inputs: List[PromptInputs] = []
        for i in range(num_requests):
614
615
            item: PromptInputs

616
            if prompts is not None:
617
618
619
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
620
            else:
621
                raise AssertionError
622
623
624
625
626
627
628

            inputs.append(item)

        return inputs

    def _validate_and_add_requests(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
    ) -> None:
        """Validate per-request arguments and enqueue one engine request each.

        Args:
            inputs: A single prompt (a ``str`` or a ``dict``) or a sequence of
                prompts.
            params: Either one params object shared by every prompt, or a
                sequence with exactly one entry per prompt.
            lora_request: Either one LoRA request (or ``None``) shared by every
                prompt, or a sequence with exactly one entry per prompt.
            prompt_adapter_request: Passed unchanged to every request.
            guided_options: Optional guided-decoding request, handed to
                ``_add_guided_processor`` for every ``SamplingParams``.

        Raises:
            ValueError: If a per-prompt ``params`` or ``lora_request``
                sequence does not contain one entry per prompt.
        """
        if isinstance(inputs, (str, dict)):
            # Convert a single prompt to a list.
            inputs = [inputs]

        num_requests = len(inputs)

        # Use the same ``Sequence`` test for validation, preprocessing and the
        # per-request indexing below. Testing only for ``list`` here (while
        # indexing on ``Sequence``) let e.g. a tuple of params skip the length
        # check and the guided-decoding / output-kind setup, yet still be
        # indexed per request.
        params_is_seq = isinstance(params, Sequence)
        lora_is_seq = isinstance(lora_request, Sequence)

        if params_is_seq and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if lora_is_seq and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if params_is_seq else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_processor(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, request_inputs in enumerate(inputs):
            self._add_request(
                request_inputs,
                params[i] if params_is_seq else params,
                lora_request=lora_request[i] if lora_is_seq else lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

    def _add_request(
        self,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Submit one prompt to the engine under a freshly allocated id.

        The id is the next value drawn from ``self.request_counter``, converted
        to ``str``. ``_run_engine`` converts it back with ``int`` to restore
        submission order.

        Args:
            inputs: The prompt for this single request.
            params: Sampling or pooling parameters for this request.
            lora_request: Optional LoRA adapter for this request.
            prompt_adapter_request: Optional prompt adapter for this request.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            inputs,
            params,
            prompt_adapter_request=prompt_adapter_request,
            lora_request=lora_request,
        )

    def _add_guided_processor(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Attach a guided-decoding logits processor to ``params`` if asked.

        When ``guided_options`` is falsy, ``params`` is returned untouched.
        Otherwise a processor is built for the chosen backend and, if one is
        produced, appended to ``params.logits_processors``.

        Returns:
            The same ``params`` object (mutated in place when a processor was
            attached).
        """
        if not guided_options:
            return params

        if guided_options.guided_decoding_backend is None:
            # Fall back to the engine's configured backend. Note that this
            # writes the default into the caller's ``guided_options`` object.
            engine_decoding = self.llm_engine.get_decoding_config()
            guided_options.guided_decoding_backend = (
                engine_decoding.guided_decoding_backend)

        processor = get_local_guided_decoding_logits_processor(
            guided_options.guided_decoding_backend, guided_options,
            self.get_tokenizer())
        if not processor:
            return params

        # NOTE(review): a processor is appended on every call, so reusing the
        # same SamplingParams object across calls accumulates processors —
        # confirm that callers never do this.
        if params.logits_processors is None:
            params.logits_processors = []
        params.logits_processors.append(processor)
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Show a tqdm progress bar over the requests that are
                unfinished at entry. Its postfix shows estimated input/output
                token throughput, computed from ``RequestOutput`` results.

        Returns:
            Every finished output, sorted by numeric request id, which is the
            order in which requests were submitted.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if not output.finished:
                    continue
                outputs.append(output)
                if not use_tqdm:
                    continue
                if isinstance(output, RequestOutput):
                    # Calculate tokens only for RequestOutput
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids)
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs)
                    # ``elapsed`` can be 0: a disabled bar (e.g. via
                    # TQDM_DISABLE) always reports 0, and a coarse clock may
                    # not have advanced yet. Skip the rate update instead of
                    # raising ZeroDivisionError.
                    elapsed = pbar.format_dict["elapsed"]
                    if elapsed:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                pbar.update(1)

        if use_tqdm:
            pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))

    def _is_encoder_decoder_model(self):
        return self.llm_engine.is_encoder_decoder_model()

    def _is_embedding_model(self):
        return self.llm_engine.is_embedding_model()