llm.py 39.3 KB
Newer Older
1
import itertools
2
import warnings
3
from contextlib import contextmanager
4
5
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
                    Union, cast, overload)
6

7
from tqdm import tqdm
8

9
10
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
11
from vllm.engine.arg_utils import EngineArgs, TaskOption
Woosuk Kwon's avatar
Woosuk Kwon committed
12
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
13
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
14
15
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
nunjunj's avatar
nunjunj committed
16
                                         parse_chat_messages)
17
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
18
from vllm.inputs.parse import parse_and_batch_prompt
19
from vllm.logger import init_logger
20
from vllm.lora.request import LoRARequest
21
22
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
23
24
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
25
from vllm.prompt_adapter.request import PromptAdapterRequest
26
27
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
28
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
29
30
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
31
from vllm.usage.usage_lib import UsageContext
32
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
33

34
35
logger = init_logger(__name__)

36
37

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
38
39
40
41
42
43
44
45
46
47
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
48
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
49
50
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
51
52
53
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
54
55
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
Woosuk Kwon's avatar
Woosuk Kwon committed
56
57
58
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
59
60
61
62
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
63
        quantization: The method used to quantize the model weights. Currently,
64
            we support "awq", "gptq", and "fp8" (experimental).
65
66
67
68
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
69
70
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
71
72
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
73
74
75
76
77
78
79
80
81
82
83
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
84
85
86
87
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
88
89
90
91
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
92
93
94
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
95
            When a sequence has context length larger than this, we fall back
96
97
98
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
99
        disable_custom_all_reduce: See ParallelConfig
100
101
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)
nunjunj's avatar
nunjunj committed
102

103
104
105
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
Woosuk Kwon's avatar
Woosuk Kwon committed
106
    """
107

108
109
110
    # Class-wide switch, read lazily (through a lambda) by the
    # ``deprecate_kwargs`` decorator guarding the legacy ``prompt_token_ids``
    # keyword. Toggled by :meth:`deprecate_legacy_api`.
    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

111
112
113
114
115
116
    # Class-wide switch, read lazily (through a lambda) by the
    # ``deprecate_args`` decorator applied to ``__init__``.
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

117
118
119
120
121
122
123
124
125
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

126
127
128
129
130
131
132
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        **kwargs,
    ) -> None:
        """
        LLM constructor.

        Collects the named arguments plus any extra ``**kwargs`` into an
        ``EngineArgs`` instance, builds the ``LLMEngine`` from it and
        creates the request counter.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.
        """
        # Stats logging is switched off unless the caller asked otherwise.
        kwargs.setdefault("disable_log_stats", True)

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            mm_processor_kwargs=mm_processor_kwargs,
            **kwargs,
        )

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        self.request_counter = Counter()

197
198
199
200
201
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer whose class name does not start with ``"Cached"`` is
        wrapped with ``get_cached_tokenizer`` first; otherwise it is stored
        as given.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        is_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer
                           if is_cached else get_cached_tokenizer(tokenizer))
210

211
212
213
214
215
216
217
218
    # Typing-only overloads for ``generate``. The five LEGACY signatures
    # accept ``prompt_token_ids``; the final signature takes ``prompts`` as a
    # positional-only argument and everything else as keyword-only.
    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    # Non-legacy signature: no ``prompt_token_ids`` parameter.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
285
286
287
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing pre-tokenized prompts.
                When given, it is combined with ``prompts`` through
                ``_convert_v1_inputs``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                is converted to ``GuidedDecodingRequest`` and must not hold
                more than one entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized for the
                ``generate`` task, or if a dict ``guided_options_request``
                specifies more than one option.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass the prompts positionally via the ``prompts``
            parameter.
        """
        model_config = self.llm_engine.model_config
        task = model_config.task
        if task != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]
            # Tell the user how to recover when the model could generate but
            # was initialized for a different task.
            if "generate" in model_config.supported_tasks:
                messages.append(
                    "Your model supports the 'generate' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task generate`.")
            raise ValueError(" ".join(messages))

        parsed_prompts: Union[PromptType, Sequence[PromptType]]
        if prompt_token_ids is None:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)
        else:
            # Legacy call style: merge text prompts and token ids.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            # Fall back to default sampling params when none were given.
            params=(SamplingParams()
                    if sampling_params is None else sampling_params),
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        engine_outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(engine_outputs, RequestOutput)
382

383
384
385
    def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Each step runs :meth:`generate` for a single token on every live
        beam, expands each beam with the returned candidate tokens and keeps
        the ``beam_width`` best-scoring beams per prompt.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        Returns:
            One ``BeamSearchOutput`` per prompt, in input order, holding at
            most ``params.beam_width`` sequences sorted from highest to
            lowest beam-search score, with ``text`` decoded from the tokens.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.get_tokenizer()
        # Looked up once; it is needed for every candidate token and for
        # every sort key evaluation.
        eos_token_id = tokenizer.eos_token_id

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            prompt_tokens = prompt if isinstance(
                prompt, list) else tokenizer.encode(prompt)
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten the live beams of all prompts into one batch. Chaining
            # avoids the quadratic copying of ``sum(lists, [])``.
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(instance.beams
                                              for instance in instances))

            if not all_beams:
                break

            # (start, end) slice of ``all_beams`` owned by each instance.
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still live when the step budget ran out compete with the
            # ones that already finished on EOS.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
491
492
    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the :meth:`generate` method to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.
                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be `True`
                if `add_generation_prompt` is also `True`.
            tools: Tool definitions forwarded unchanged to the chat template,
                if any.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is List[...]
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]

        # These do not depend on the conversation, so look them up once
        # instead of once per conversation.
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs, model_config, tokenizer)

            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                # The Mistral template is applied to the raw messages, not
                # to the parsed conversation.
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            # A template may yield token ids or text; wrap accordingly.
            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

606
607
608
609
610
611
612
613
    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
614
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
615
616
    ) -> List[EmbeddingRequestOutput]:
        ...
617

618
    @overload  # LEGACY: multi (prompt + optional token ids)
619
620
    def encode(
        self,
621
        prompts: List[str],
622
        pooling_params: Optional[Union[PoolingParams,
623
                                       Sequence[PoolingParams]]] = None,
624
625
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
626
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
627
628
629
630
631
632
633
634
635
636
637
638
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
639
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
640
641
642
643
644
645
646
647
648
649
650
651
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Typing-only stub for the legacy call shape "a batch of token-id
        # lists passed by keyword, text prompts optional". The runtime
        # behavior lives in the undecorated ``encode`` definition.
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Typing-only stub for the legacy call shape where the first two
        # positional slots are explicitly ``None`` and token ids come third.
        # The runtime behavior lives in the undecorated ``encode`` definition.
        ...

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Typing-only stub for the current (non-legacy) call shape: prompts
        # are positional-only and every other option is keyword-only. The
        # runtime behavior lives in the undecorated ``encode`` definition.
        ...

nunjunj's avatar
nunjunj committed
680
681
682
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            prompt_token_ids: LEGACY. Token ids for a single prompt or for a
                batch of prompts. When given, the inputs are converted with
                :meth:`_convert_v1_inputs`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the
                ``"embedding"`` task, or if neither ``prompts`` nor
                ``prompt_token_ids`` is provided.

        Note:
            Using ``prompt_token_ids`` is considered legacy and may be
            deprecated in the future. You should instead pass all inputs
            via the ``prompts`` parameter.
        """
        task = self.llm_engine.model_config.task
        if task != "embedding":
            messages = ["LLM.encode() is only supported for embedding models."]

            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "embedding" in supported_tasks:
                # The model could embed, it was just started for another task.
                messages.append(
                    "Your model supports the 'embedding' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task embedding`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call shape: normalize text / token ids into prompts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            if prompts is None:
                # Fail here with a clear message instead of a later
                # ``len(None)`` TypeError during request validation.
                raise ValueError("Either prompts or prompt_token_ids must be "
                                 "provided.")
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
757

758
759
760
761
762
763
    def start_profile(self) -> None:
        """Ask the underlying engine to start profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

764
765
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Convert legacy ``prompts`` / ``prompt_token_ids`` into prompts.

        Both arguments are first normalized to batches with
        ``parse_and_batch_prompt``, so a single prompt and a list of prompts
        are handled the same way.

        Args:
            prompts: One text prompt, a list of text prompts, or None.
            prompt_token_ids: One token-id list, a list of token-id lists,
                or None.

        Returns:
            A list with one entry per request: ``TextPrompt`` entries when
            ``prompts`` is given, otherwise ``TokensPrompt`` entries.

        Raises:
            ValueError: If both arguments are None, or if both are given
                but their batch sizes differ.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        parsed_prompts: List[PromptType]
        if prompts is not None:
            # When both inputs are supplied the text takes precedence; the
            # token ids have only been used for the length check above.
            parsed_prompts = [TextPrompt(prompt=text) for text in prompts]
        else:
            assert prompt_token_ids is not None
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=token_ids)
                for token_ids in prompt_token_ids
            ]

        return parsed_prompts
807
808
809

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate per-request arguments and enqueue every prompt.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: One params object shared by all prompts, or a sequence
                with one entry per prompt.
            lora_request: One LoRA request shared by all prompts, a sequence
                with one entry per prompt, or None.
            prompt_adapter_request: Prompt adapter request forwarded
                unchanged to every request.
            guided_options: Deprecated guided-decoding options; applied to
                every ``SamplingParams`` via :meth:`_add_guided_params`.
            priority: Optional per-prompt priorities; 0 is used when absent.

        Raises:
            ValueError: If ``params`` or ``lora_request`` is a sequence whose
                length differs from the number of prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use the same ``Sequence`` test here as for the per-request indexing
        # below, so that tuples are length-checked and post-processed exactly
        # like lists instead of silently skipping validation.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
856

857
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Enqueue a single request on the engine under a fresh request id.

        The id is the next value of ``self.request_counter`` rendered as a
        string; all other arguments are forwarded to
        ``self.llm_engine.add_request`` unchanged.
        """
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
874

875
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy deprecated guided-decoding options onto ``params``.

        Args:
            params: Sampling parameters to update in place.
            guided_options: Legacy guided-decoding request, or None.

        Returns:
            The same ``params`` object. It is returned untouched when
            ``guided_options`` is None; otherwise its ``guided_decoding``
            attribute is set from ``guided_options``.

        Raises:
            ValueError: If ``guided_options`` is given while
                ``params.guided_decoding`` is already set.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            # The trailing space keeps the two literals from running together
            # as "andparams.guided_decoding." in the rendered message.
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)
        return params

896
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to show a tqdm progress bar that advances once
                per finished request and reports estimated token throughput.

        Returns:
            The finished outputs, sorted by the integer value of their
            ``request_id``.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue

                    outputs.append(output)
                    if pbar is None:
                        continue

                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)

                        elapsed = pbar.format_dict["elapsed"]
                        # Skip the speed refresh rather than divide by zero
                        # if the bar reports no elapsed time yet.
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even when a step raises.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
940
941
942

    def _is_encoder_decoder_model(self):
        """Return the engine's answer to ``is_encoder_decoder_model()``."""
        engine = self.llm_engine
        return engine.is_encoder_decoder_model()