llm.py 31.2 KB
Newer Older
1
from contextlib import contextmanager
2
3
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Union, cast,
                    overload)
4

5
from tqdm import tqdm
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
9
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
10
11
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
nunjunj's avatar
nunjunj committed
12
                                         parse_chat_messages)
13
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
14
from vllm.inputs.parse import parse_and_batch_prompt
15
from vllm.logger import init_logger
16
from vllm.lora.request import LoRARequest
17
18
19
from vllm.model_executor.guided_decoding import (
    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
20
21
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
22
from vllm.prompt_adapter.request import PromptAdapterRequest
23
from vllm.sampling_params import RequestOutputKind, SamplingParams
24
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
25
26
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
27
from vllm.usage.usage_lib import UsageContext
28
from vllm.utils import Counter, deprecate_kwargs, is_list_of
29

30
31
logger = init_logger(__name__)

32
33

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If left unset (None), it defaults to False.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See ParallelConfig
        disable_async_output_proc: Passed through to :class:`~vllm.EngineArgs`.
        mm_processor_kwargs: Extra keyword arguments for the multi-modal
            processor; passed through to :class:`~vllm.EngineArgs`.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

116
117
118
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        '''
        LLM constructor.

        Every argument is forwarded to :class:`~vllm.EngineArgs`, and the
        resulting engine is stored on ``self.llm_engine``.
        ``disable_log_stats`` defaults to True unless the caller passes it.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.

        Raises:
            TypeError: If one of the removed vision-related arguments
                (``image_token_id``, ``image_feature_size``,
                ``image_input_shape``, ``image_input_type``) is passed.
        '''
        # Offline use turns stats logging off unless explicitly requested.
        kwargs.setdefault("disable_log_stats", True)

        legacy_vision_args = {
            "image_token_id",
            "image_feature_size",
            "image_input_shape",
            "image_input_type",
        }
        if legacy_vision_args.intersection(kwargs):
            raise TypeError(
                "There is no need to pass vision-related arguments anymore.")

        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            mm_processor_kwargs=mm_processor_kwargs,
            **kwargs,
        )

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        # Source of monotonically increasing request IDs for _add_request.
        self.request_counter = Counter()

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` into the engine's tokenizer group.

        Tokenizers that do not look cached are first wrapped with
        ``get_cached_tokenizer``.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # Cached tokenizer classes are generated dynamically, so the class
        # name is the only signal available. A user-defined tokenizer whose
        # class name starts with "Cached" will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))

    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs(
        "prompts",
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        # There is no 'inputs' parameter any more: the non-legacy API takes
        # the prompt(s) as the first positional argument (see the last
        # overload above).
        additional_message=(
            "Please pass the prompts as the first positional argument "
            "instead (use TokensPrompt for token IDs)."),
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: LEGACY. Token IDs of the prompt(s); prefer
                passing :class:`~vllm.inputs.TokensPrompt` via ``prompts``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options applied to every
                request's sampling parameters. A dict (``LLMGuidedOptions``)
                is converted into a ``GuidedDecodingRequest`` and may contain
                only a single entry.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model runs in embedding mode, or if a dict
                ``guided_options_request`` specifies more than one option.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass the prompts as the first positional argument, using
            :class:`~vllm.inputs.TokensPrompt` for token IDs.
        """
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).")

        if prompt_token_ids is not None:
            # Legacy call style: build prompt dicts from the v1 arguments.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)

    def chat(
        self,
        messages: List[ChatCompletionMessageParam],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a single chat conversation.

        The conversation is rendered into one prompt with the tokenizer's
        chat template (or ``chat_template`` when given). The rendered
        prompt is then handed to :meth:`generate`. Multi-modal inputs can
        be passed in the same way you would pass them to the OpenAI API.

        Args:
            messages: A single conversation represented as a list of messages.
                Each message is a dictionary with 'role' and 'content' keys.
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            add_generation_prompt: Forwarded to the chat-template renderer as
                ``add_generation_prompt``. It controls whether a generation
                prompt is added after the conversation.
            tools: Tool definitions forwarded to the chat-template renderer.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()

        conversation, mm_data = parse_chat_messages(messages, model_config,
                                                    tokenizer)

        # Mistral templates render the original messages; HF templates
        # render the parsed conversation.
        rendered: Union[str, List[int]]
        if not isinstance(tokenizer, MistralTokenizer):
            rendered = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                tools=tools,
            )
        else:
            rendered = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                tools=tools,
            )

        chat_prompt: PromptType = (
            TokensPrompt(prompt_token_ids=rendered)
            if is_list_of(rendered, int) else TextPrompt(prompt=rendered))
        if mm_data is not None:
            chat_prompt["multi_modal_data"] = mm_data

        return self.generate(chat_prompt,
                             sampling_params=sampling_params,
                             use_tqdm=use_tqdm,
                             lora_request=lora_request)

    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @deprecate_kwargs(
        "prompts",
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        # There is no 'inputs' parameter any more: the non-legacy API takes
        # the prompt(s) as the first positional argument (see the last
        # overload above).
        additional_message=(
            "Please pass the prompts as the first positional argument "
            "instead (use TokensPrompt for token IDs)."),
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. When it is a sequence, it
                must have one entry per prompt.
            prompt_token_ids: LEGACY. Token IDs of the prompt(s); prefer
                passing :class:`~vllm.inputs.TokensPrompt` via ``prompts``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running in embedding mode.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass the prompts as the first positional argument, using
            :class:`~vllm.inputs.TokensPrompt` for token IDs.
        """
        if not self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is not None:
            # Legacy call style: build prompt dicts from the v1 arguments.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)

    def start_profile(self) -> None:
        self.llm_engine.start_profile()

    def stop_profile(self) -> None:
        self.llm_engine.stop_profile()

589
590
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Turn legacy ``prompts`` / ``prompt_token_ids`` into prompt dicts.

        Both arguments are batched with ``parse_and_batch_prompt``. When text
        prompts are present they win: one ``TextPrompt`` is built per text.
        The token IDs are then only used for the length check. Otherwise one
        ``TokensPrompt`` is built per token-ID list.

        Raises:
            ValueError: If both are given with different batch sizes, or if
                neither is given.
        """
        # skip_tokenizer_init is now checked in engine
        texts = (None if prompts is None else
                 [p["content"] for p in parse_and_batch_prompt(prompts)])
        token_batches = (None if prompt_token_ids is None else [
            p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
        ])

        if (texts is not None and token_batches is not None
                and len(texts) != len(token_batches)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        if texts is not None:
            return [TextPrompt(prompt=text) for text in texts]
        if token_batches is not None:
            return [TokensPrompt(prompt_token_ids=ids) for ids in token_batches]

        raise ValueError("Either prompts or prompt_token_ids must be "
                         "provided.")

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
    ) -> None:
        """Validate per-request arguments and enqueue one request per prompt.

        ``params`` and ``lora_request`` may each be a single value shared by
        all prompts, or any sequence (list, tuple, ...) with one entry per
        prompt. Every ``SamplingParams`` gets the guided-decoding processor
        (if ``guided_options`` is set) and is switched to final-only output.
        Note that this mutates the caller's ``SamplingParams`` objects.

        Raises:
            ValueError: If a per-prompt sequence does not have exactly one
                entry per prompt.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use the same ``Sequence`` test for validation, setup and indexing.
        # Checking ``list`` here while indexing any ``Sequence`` below let
        # tuples skip the length check and the per-params setup.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_processor(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )

    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Hand one prompt to the engine under a freshly drawn request ID.

        The ID is the next value of ``self.request_counter``, converted with
        ``str``; ``_run_engine`` later sorts results by ``int(request_id)``.
        """
        new_id = next(self.request_counter)
        engine = self.llm_engine
        engine.add_request(
            str(new_id),
            prompt,
            params,
            prompt_adapter_request=prompt_adapter_request,
            lora_request=lora_request,
        )
687

688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
    def _add_guided_processor(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Attach a guided-decoding logits processor to ``params``.

        When ``guided_options`` is falsy, ``params`` is returned untouched.
        Otherwise the backend named in ``guided_options`` is used. If that
        backend is ``None``, the engine's decoding config supplies it, and
        the value is written back onto ``guided_options``.

        NOTE(review): the processor is appended to the caller's
        ``params.logits_processors`` list in place, so passing the same
        ``SamplingParams`` object on repeated calls adds one processor per
        call -- confirm this is intended.

        Returns:
            The same ``params`` object that was passed in.
        """
        if not guided_options:
            return params

        backend = guided_options.guided_decoding_backend
        if backend is None:
            decoding_config = self.llm_engine.get_decoding_config()
            backend = decoding_config.guided_decoding_backend
            guided_options.guided_decoding_backend = backend

        processor = get_local_guided_decoding_logits_processor(
            backend, guided_options, self.get_tokenizer())
        if not processor:
            # The factory returned a falsy value; nothing to attach.
            return params

        if params.logits_processors is None:
            params.logits_processors = []
        params.logits_processors.append(processor)
        return params

706
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to show a progress bar. The bar's postfix
                reports estimated input/output token throughput, which is
                computed only from ``RequestOutput`` results.

        Returns:
            Every finished output, sorted by numeric request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # Read the elapsed time once. Skip the rate update
                        # while it is 0 to avoid ZeroDivisionError (tqdm may
                        # report 0 elapsed, presumably e.g. for a disabled
                        # bar).
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even if a step raises, so it is not left open.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
750
751
752
753
754
755

    def _is_encoder_decoder_model(self):
        return self.llm_engine.is_encoder_decoder_model()

    def _is_embedding_model(self):
        return self.llm_engine.is_embedding_model()