"csrc/vscode:/vscode.git/clone" did not exist on "c6703d1e0d488a09dc76562c7335306f0f3486c0"
llm.py 37.5 KB
Newer Older
1
import itertools
2
from contextlib import contextmanager
3
4
5
from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
                    Union, cast, overload)
6

7
from tqdm import tqdm
8

Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
11
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
12
13
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
nunjunj's avatar
nunjunj committed
14
                                         parse_chat_messages)
15
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
16
from vllm.inputs.parse import parse_and_batch_prompt
17
from vllm.logger import init_logger
18
from vllm.lora.request import LoRARequest
19
20
21
from vllm.model_executor.guided_decoding import (
    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
22
23
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
24
from vllm.prompt_adapter.request import PromptAdapterRequest
25
from vllm.sampling_params import RequestOutputKind, SamplingParams
26
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
27
28
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
29
from vllm.usage.usage_lib import UsageContext
30
from vllm.utils import Counter, deprecate_kwargs, is_list_of
31

32
33
logger = init_logger(__name__)

34

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked while running beam search.

    A beam stores its token ids (the prompt comes first, followed by any
    generated tokens) together with the accumulated log probability of the
    tokens appended so far. ``text`` remains ``None`` until the beam is
    decoded right before being handed back to the caller.
    """
    tokens: List[int]  # prompt token ids followed by generated token ids
    cum_logprob: float = 0.0  # accumulated log probability of the beam
    text: Optional[str] = None  # decoded text; filled in only on return


@dataclass
class BeamSearchOutput:
    """Beam search result for a single prompt.

    ``sequences`` lists the best beams that were found, best first. It holds
    at most ``beam_width`` entries; fewer are possible when the search ran
    out of candidates.
    """
    sequences: List[BeamSearchSequence]  # best beams, highest score first


class BeamSearchInstance:
    """Mutable beam search state for one prompt.

    ``beams`` holds the sequences that are still being extended, and
    ``completed`` collects the sequences that have stopped growing.
    """

    def __init__(self, prompt_tokens: List[int]):
        # Nothing has finished yet.
        self.completed: List[BeamSearchSequence] = []
        # The search starts from a single beam that is just the prompt.
        self.beams: List[BeamSearchSequence] = [
            BeamSearchSequence(tokens=prompt_tokens),
        ]


66
class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If left unset (None), it defaults to False.
        max_context_len_to_capture: Maximum context length covered by CUDA
            graphs. When a sequence has a context length larger than this, we
            fall back to eager mode (DEPRECATED. Use `max_seq_len_to_capture`
            instead).
        max_seq_len_to_capture: Maximum sequence length covered by CUDA
            graphs. When a sequence has a context length larger than this, we
            fall back to eager mode. Additionally, for encoder-decoder models,
            if the sequence length of the encoder input is larger than this,
            we fall back to eager mode.
        disable_custom_all_reduce: See ParallelConfig
        disable_async_output_proc: Forwarded to :class:`~vllm.EngineArgs`.
        mm_processor_kwargs: Forwarded to :class:`~vllm.EngineArgs`.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

149
150
151
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """Build the engine that backs this offline LLM instance.

        Every named argument, together with any extra ``**kwargs``, is
        forwarded to :class:`~vllm.EngineArgs`. The resulting engine is
        created through :meth:`LLMEngine.from_engine_args`.

        Note: if ``enforce_eager`` is unset (None), it defaults to False.

        Raises:
            TypeError: If one of the removed vision-related keyword arguments
                is passed.
        """
        # Stats logging stays off unless the caller explicitly asks for it.
        kwargs.setdefault("disable_log_stats", True)

        removed_vision_keys = (
            "image_token_id",
            "image_feature_size",
            "image_input_shape",
            "image_input_type",
        )
        # These vision options are no longer accepted; fail loudly instead of
        # silently forwarding them to EngineArgs.
        if not kwargs.keys().isdisjoint(removed_vision_keys):
            raise TypeError(
                "There is no need to pass vision-related arguments anymore.")

        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            mm_processor_kwargs=mm_processor_kwargs,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        self.request_counter = Counter()

218
219
220
221
222
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer whose class name starts with ``"Cached"`` is stored
        unchanged. Any other tokenizer is first wrapped by
        :func:`get_cached_tokenizer`.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # The cached tokenizer class is created dynamically, so its class name
        # is the only available marker. A user-defined tokenizer whose class
        # name happens to start with 'Cached' will be misclassified.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))
231

232
233
234
235
236
237
238
239
    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        """Legacy form: one text prompt, optionally with its token ids."""

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        """Legacy form: several text prompts, optionally with token ids."""

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        """Legacy form: one token-id prompt passed by keyword."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        """Legacy form: several token-id prompts passed by keyword."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        """Legacy form: token ids passed positionally after two Nones."""

    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        """Current form: one prompt or a sequence of prompts, positionally."""

nunjunj's avatar
nunjunj committed
306
307
308
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: LEGACY. Token ids of the prompt(s); combined
                with ``prompts`` (treated as text) into prompt objects.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A plain
                dict is converted into a :class:`GuidedDecodingRequest` and
                may contain at most one entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model runs in embedding mode, or if a dict
                ``guided_options_request`` specifies more than one kind of
                guided decoding.

        Note:
            Passing ``prompt_token_ids`` (optionally together with text
            ``prompts``) is the legacy API and may be deprecated in the
            future. Pass token ids through the ``prompts`` parameter instead,
            e.g. as a :class:`~vllm.inputs.TokensPrompt`.
        """
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).")

        if prompt_token_ids is not None:
            # Legacy call style: merge text prompts and token ids.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)
392

393
394
395
396
397
398
    def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        beam_width: int,
        max_tokens: int,
        ignore_eos: bool = False,
        temperature: float = 0.0,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            beam_width: The number of beams to keep at each step.
            max_tokens: The max number of tokens to generate for each prompt.
            ignore_eos: If False (default), a candidate ending with the
                tokenizer's EOS token is moved to the completed set and is
                not extended further. If True, EOS is treated like any other
                token.
            temperature: The temperature to use for generation.

        Returns:
            One :class:`BeamSearchOutput` per prompt, in input order. Each
            holds at most ``beam_width`` sequences sorted by cumulative log
            probability, highest first. A sequence's ``tokens`` include the
            prompt, and its ``text`` is ``tokenizer.decode(tokens)``.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """
        tokenizer = self.get_tokenizer()
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)

        instances: List[BeamSearchInstance] = []
        for prompt in prompts:
            # Copy token-id prompts so the caller's list is never aliased by
            # the beams that are returned.
            prompt_tokens = list(prompt) if isinstance(
                prompt, list) else tokenizer.encode(prompt)
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten the live beams of every prompt into one batch. chain()
            # is linear, unlike sum(lists, []), which re-copies on every add.
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(instance.beams
                                              for instance in instances))
            if not all_beams:
                break

            # all_beams[start:end] belong to the matching instance; pos holds
            # the running boundaries between instances.
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # Only runs for one step, so no progress bar is needed. generate()
            # returns outputs in prompt order, so output[i] matches
            # all_beams[i].
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if (token_id == tokenizer.eos_token_id
                                    and not ignore_eos):
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                # Keep only the beam_width highest-scoring live beams.
                sorted_beams = sorted(instance_new_beams,
                                      key=lambda x: x.cum_logprob,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after max_tokens steps compete with the
            # completed ones for the final slots.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=lambda x: x.cum_logprob,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
494
495
    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each chat conversation is rendered into a prompt with the tokenizer's
        chat template. The result is a text prompt, or a token-id prompt when
        the template renderer returns token ids. The prompts are then passed
        to the :meth:`generate` method to generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.
                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be
                used.
            add_generation_prompt: If True, adds a generation template
                to each rendered conversation.
            tools: Optional tool definitions, forwarded unchanged to the
                chat-template renderer (both the HF and Mistral paths).

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = messages
        else:
            # messages is List[...]
            list_of_messages = [messages]

        # Loop-invariant: fetch these once instead of once per conversation.
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            conversation, mm_data = parse_chat_messages(
                msgs, model_config, tokenizer)

            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    tools=tools,
                )

            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

591
592
593
594
595
596
597
598
    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: one text prompt, optionally with its token ids."""

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: several text prompts, optionally with token ids."""

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: one token-id prompt passed by keyword."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: several token-id prompts passed by keyword."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: token ids passed positionally after two Nones."""

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Current form: one prompt or a sequence of prompts, positionally."""

nunjunj's avatar
nunjunj committed
665
666
667
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generate embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Either a single value
                shared by every prompt or a sequence with one entry per prompt.
            prompt_token_ids: (Legacy) token IDs for one request or a batch of
                requests. When given, ``prompts`` is interpreted as legacy
                text (a ``str`` or list of ``str``).
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. Either a
                single request or a list with one entry per prompt.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the loaded model is not in embedding mode, or if a
                per-prompt argument's length does not match the number of
                prompts.

        Note:
            Passing token IDs through the ``prompt_token_ids`` keyword
            (optionally paired with text ``prompts``) is considered legacy and
            may be deprecated in the future. You should instead pass
            :class:`~vllm.inputs.TextPrompt` /
            :class:`~vllm.inputs.TokensPrompt` objects positionally via
            ``prompts``.
        """
        if not self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is not None:
            # Legacy call style: convert text / token IDs into prompt dicts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)

    def start_profile(self) -> None:
        """Forward a start-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Forward a stop-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

741
742
    # LEGACY
    def _convert_v1_inputs(
743
744
        self,
        prompts: Optional[Union[str, List[str]]],
745
746
747
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        # skip_tokenizer_init is now checked in engine
748

749
750
751
752
753
754
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]
755

756
        num_requests = None
757
758
        if prompts is not None:
            num_requests = len(prompts)
759
760
761
762
763
764
        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

765
            num_requests = len(prompt_token_ids)
766
767
768
769
        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

770
        parsed_prompts: List[PromptType] = []
771
        for i in range(num_requests):
772
            item: PromptType
773

774
            if prompts is not None:
775
776
777
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
778
            else:
779
                raise AssertionError
780

781
            parsed_prompts.append(item)
782

783
        return parsed_prompts
784
785
786

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate per-request arguments, then submit each prompt.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt.
            lora_request: One LoRA request (or None) shared by every prompt,
                or a sequence with exactly one entry per prompt.
            prompt_adapter_request: Applied to every prompt.
            guided_options: If given, a guided-decoding logits processor is
                attached to every ``SamplingParams`` in ``params``.
            priority: Optional per-prompt priorities. None or an empty list
                means every request gets priority 0.

        Raises:
            ValueError: If a per-prompt argument's length does not match the
                number of prompts. All checks run before any request is
                submitted, so a mismatch never leaves partial work queued in
                the engine.

        Note:
            ``SamplingParams`` objects are mutated in place: guided-decoding
            processors may be appended and ``output_kind`` is set to
            ``FINAL_ONLY``.
        """
        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Validate with the same `Sequence` test that is used for indexing
        # below, so tuples are checked just like lists.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # A short priority list would otherwise raise IndexError midway,
        # after earlier prompts were already added to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_processor(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )

    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the engine under a fresh request ID.

        The ID is the next value of ``self.request_counter`` converted to
        ``str``; ``_run_engine`` relies on it being an integer string when it
        sorts outputs with ``int(request_id)``.
        """
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

    def _add_guided_processor(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Append a guided-decoding logits processor to ``params`` if needed.

        When ``guided_options`` is falsy, nothing happens. When it has no
        backend set, the backend is filled in on ``guided_options`` itself
        from the engine's decoding config before the processor is built.

        Returns:
            The same ``params`` object, with the processor appended to its
            ``logits_processors`` list when one was produced.
        """
        if not guided_options:
            return params

        if guided_options.guided_decoding_backend is None:
            engine_decoding = self.llm_engine.get_decoding_config()
            guided_options.guided_decoding_backend = (
                engine_decoding.guided_decoding_backend)

        processor = get_local_guided_decoding_logits_processor(
            guided_options.guided_decoding_backend, guided_options,
            self.get_tokenizer())
        if not processor:
            return params

        if params.logits_processors is None:
            params.logits_processors = []
        params.logits_processors.append(processor)
        return params

862
    def _run_engine(
863
            self, *, use_tqdm: bool
864
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
865
866
        # Initialize tqdm.
        if use_tqdm:
Zhuohan Li's avatar
Zhuohan Li committed
867
            num_requests = self.llm_engine.get_num_unfinished_requests()
868
869
870
871
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
872
873
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
874
            )
875

Zhuohan Li's avatar
Zhuohan Li committed
876
        # Run the engine.
877
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
878
879
        total_in_toks = 0
        total_out_toks = 0
Zhuohan Li's avatar
Zhuohan Li committed
880
881
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
882
            for output in step_outputs:
883
                if output.finished:
884
885
                    outputs.append(output)
                    if use_tqdm:
886
887
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
888
                            assert output.prompt_token_ids is not None
889
890
891
                            total_in_toks += len(output.prompt_token_ids)
                            in_spd = total_in_toks / pbar.format_dict["elapsed"]
                            total_out_toks += sum(
892
                                len(stp.token_ids) for stp in output.outputs)
nunjunj's avatar
nunjunj committed
893
894
                            out_spd = (total_out_toks /
                                       pbar.format_dict["elapsed"])
895
896
897
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
898
                        pbar.update(1)
899

900
901
        if use_tqdm:
            pbar.close()
902
903
904
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
905
        return sorted(outputs, key=lambda x: int(x.request_id))
906
907
908
909
910
911

    def _is_encoder_decoder_model(self):
        return self.llm_engine.is_encoder_decoder_model()

    def _is_embedding_model(self):
        return self.llm_engine.is_embedding_model()