llm.py 38 KB
Newer Older
1
import itertools
2
import warnings
3
from contextlib import contextmanager
4
5
6
from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
                    Union, cast, overload)
7

8
from tqdm import tqdm
9

Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
12
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
13
14
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
nunjunj's avatar
nunjunj committed
15
                                         parse_chat_messages)
16
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
17
from vllm.inputs.parse import parse_and_batch_prompt
18
from vllm.logger import init_logger
19
from vllm.lora.request import LoRARequest
20
21
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
22
23
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
24
from vllm.prompt_adapter.request import PromptAdapterRequest
25
26
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
27
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
28
29
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
30
from vllm.usage.usage_lib import UsageContext
31
32
from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score,
                        is_list_of)
33

34
35
# Module-level logger, named after this module via ``__name__``.
logger = init_logger(__name__)

36

37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
@dataclass
class BeamSearchSequence:
    """One candidate sequence tracked while running beam search.

    Attributes:
        tokens: Token IDs of the candidate, with the prompt tokens included
            at the front.
        cum_logprob: Running total of the log-probabilities of the tokens
            appended so far; starts at ``0.0``.
        text: Decoded text of ``tokens``. Stays ``None`` until the sequence
            is about to be returned to the user.
    """
    tokens: List[int]  # prompt tokens + generated tokens
    cum_logprob: float = 0.0
    text: Optional[str] = None


@dataclass
class BeamSearchOutput:
    """The result of beam search for a single prompt.

    Attributes:
        sequences: The best beam search sequences, ordered best-first by
            beam search score. It holds *at most* ``beam_width`` entries;
            it can be shorter when fewer candidates survived the search.
    """
    sequences: List[BeamSearchSequence]


class BeamSearchInstance:
    """Mutable beam search state for a single prompt.

    ``beams`` holds the live (still-extendable) candidates and starts out
    as just the prompt itself. ``completed`` collects candidates that have
    finished (e.g. by emitting EOS) and no longer take part in expansion.
    """

    def __init__(self, prompt_tokens: Sequence[int]):
        # Copy the prompt so the beam never aliases the caller's list: the
        # initial beam can be returned unchanged to the user (for example
        # when ``max_tokens`` is 0), and later mutation of the caller's list
        # must not leak into it. Copying also accepts tuples / any sequence.
        self.beams: List[BeamSearchSequence] = [
            BeamSearchSequence(tokens=list(prompt_tokens))
        ]
        self.completed: List[BeamSearchSequence] = []


68
class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
69
70
71
72
73
74
75
76
77
78
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
79
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
80
81
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
82
83
84
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
85
86
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
Woosuk Kwon's avatar
Woosuk Kwon committed
87
88
89
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
90
91
92
93
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
94
        quantization: The method used to quantize the model weights. Currently,
95
            we support "awq", "gptq", and "fp8" (experimental).
96
97
98
99
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
100
101
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
102
103
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
104
105
106
107
108
109
110
111
112
113
114
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
115
116
117
118
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
119
120
121
122
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
123
124
125
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
126
            When a sequence has context length larger than this, we fall back
127
128
129
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
130
        disable_custom_all_reduce: See ParallelConfig
131
132
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)
nunjunj's avatar
nunjunj committed
133

134
135
136
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
Woosuk Kwon's avatar
Woosuk Kwon committed
137
    """
138

139
140
141
142
143
144
145
146
147
148
149
150
    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Treat the legacy ``prompt_token_ids`` API as deprecated while active.

        Inside the ``with`` block ``cls.DEPRECATE_LEGACY`` is ``True``. That
        is the flag read by the ``is_deprecated`` callbacks passed to
        ``deprecate_kwargs`` on :meth:`generate` and :meth:`encode`.

        On exit the flag is restored to the value it had on entry. This
        happens even if the body raises, so the flag can never stay stuck at
        ``True``. Nested uses also compose correctly: an inner exit does not
        switch the flag off while an outer context is still active.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = previous

151
152
153
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """Create the engine that backs this ``LLM`` instance.

        Every named argument, together with anything in ``kwargs``, is
        forwarded to :class:`~vllm.EngineArgs`. The engine built from those
        arguments is stored on ``self.llm_engine``, and a fresh request
        counter on ``self.request_counter``.

        Note:
            If ``enforce_eager`` is left unset (``None``) it defaults to
            ``False``. ``disable_log_stats`` defaults to ``True`` here unless
            the caller passes it explicitly in ``kwargs``.
        """
        # Only fill in the default; an explicit caller value always wins.
        kwargs.setdefault("disable_log_stats", True)

        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            mm_processor_kwargs=mm_processor_kwargs,
            **kwargs,
        )

        engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        self.llm_engine = engine
        self.request_counter = Counter()

212
213
214
215
216
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's ``TokenizerGroup``."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's ``TokenizerGroup``.

        A tokenizer whose class name starts with ``"Cached"`` is stored as
        given. Any other tokenizer is first wrapped with
        :func:`get_cached_tokenizer`.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # Cached tokenizer classes are created dynamically, so there is no
        # fixed class to isinstance-check against; the class-name prefix is
        # the only available signal. A user-defined tokenizer whose class
        # name happens to start with "Cached" will be misclassified.
        already_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if already_cached else
                           get_cached_tokenizer(tokenizer))
225

226
227
228
229
230
231
232
233
    # Typing overloads for ``generate``. Every overload must accept every
    # keyword the runtime implementation accepts (``prompt_adapter_request``,
    # ``guided_options_request``, ``priority``); otherwise type-checked
    # callers that pass them match no overload at all.
    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # Current API: PromptType inputs, no prompt_token_ids
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        # The implementation accepts sampling_params positionally, so the
        # overload allows it too (it used to be keyword-only here).
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
300
301
302
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: LEGACY. Token IDs of the prompt(s), optionally
                paired with string ``prompts``. Prefer passing
                :class:`~vllm.inputs.TokensPrompt` objects via ``prompts``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                (``LLMGuidedOptions``) may specify at most one option and is
                converted into a ``GuidedDecodingRequest``.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model runs in embedding mode, or if
                ``guided_options_request`` is a dict with more than one entry.

        Note:
            Passing ``prompt_token_ids`` (optionally together with string
            ``prompts``) is considered legacy and may be deprecated in the
            future. You should instead pass
            :class:`~vllm.inputs.PromptType` objects (for example
            ``TokensPrompt``) via the ``prompts`` parameter.
        """
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).")

        if prompt_token_ids is not None:
            # Legacy path: merge string prompts and token ids into the
            # PromptType representation used by the rest of the pipeline.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            # Only a single guided decoding mode can be active per request.
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)
386

387
388
389
    def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Each prompt is tokenized (unless it is already a list of token IDs)
        and then extended one token at a time for up to
        ``params.max_tokens`` steps.

        At every step, the live beams of all prompts are sent through
        :meth:`generate` as one batch, requesting the top ``2 * beam_width``
        next-token log-probabilities. A candidate that ends in the EOS token
        moves to the completed set, unless ``params.ignore_eos`` is set. All
        other candidates are ranked with ``get_beam_search_score``, which
        applies ``params.length_penalty``, and the best ``beam_width`` are
        kept.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per prompt, in input order. Each
            holds up to ``beam_width`` sequences sorted best-first. A
            sequence's ``tokens`` and decoded ``text`` include the prompt.

        TODO: how does beam search work together with frequency penalty,
        stopping criteria, etc.?
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Fetch the tokenizer before defining the ranking key that uses it.
        tokenizer = self.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)

        instances: List[BeamSearchInstance] = [
            BeamSearchInstance(prompt if isinstance(prompt, list) else
                               tokenizer.encode(prompt))
            for prompt in prompts
        ]

        for _ in range(max_tokens):
            # Flatten every instance's live beams into one batch. Unlike
            # ``sum(lists, [])``, which re-copies the accumulator on each
            # addition (quadratic), chain() is linear.
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # pos[k]:pos[k + 1] is instance k's slice of all_beams / output.
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams: List[BeamSearchSequence] = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    step_logprobs = output[i].outputs[0].logprobs
                    if step_logprobs is None:
                        # if the logprobs are None, it means the sequence is
                        # completed because of the max-model-len or abortion.
                        # we don't need to add it to the new beams.
                        continue
                    for token_id, logprob_obj in step_logprobs[0].items():
                        new_beam = BeamSearchSequence(
                            tokens=current_beam.tokens + [token_id],
                            cum_logprob=current_beam.cum_logprob +
                            logprob_obj.logprob)

                        if token_id == eos_token_id and not ignore_eos:
                            instance.completed.append(new_beam)
                        else:
                            instance_new_beams.append(new_beam)
                instance_new_beams.sort(key=sort_beams_key, reverse=True)
                instance.beams = instance_new_beams[:beam_width]

        outputs: List[BeamSearchOutput] = []
        for instance in instances:
            # Unfinished beams compete with finished ones for the final slots.
            instance.completed.extend(instance.beams)
            best_beams = sorted(instance.completed,
                                key=sort_beams_key,
                                reverse=True)[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
494
495
    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the :meth:`generate` method to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.
                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            add_generation_prompt: If True, the chat template appends a
                generation prompt after each conversation's last message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be `True`
                if `add_generation_prompt` is also `True`.
            tools: Optional list of tool definitions (dicts) that are passed
                to the chat template when rendering each prompt.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = messages
        else:
            # messages is List[...]
            list_of_messages = [messages]

        # Loop-invariant: the same tokenizer and model config are used for
        # every conversation, so fetch them once rather than per iteration.
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            conversation, mm_data = parse_chat_messages(
                msgs, model_config, tokenizer)

            # Mistral tokenizers render from the raw messages; HF tokenizers
            # render from the parsed conversation.
            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            # The template may yield token IDs or text; wrap accordingly.
            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

597
598
599
600
601
602
603
604
    # Typing overloads for ``encode``. Every overload must accept every
    # keyword the runtime implementation accepts (``prompt_adapter_request``);
    # otherwise type-checked callers that pass it match no overload at all.
    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # Current API: PromptType inputs, no prompt_token_ids
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        # The implementation accepts pooling_params positionally, so the
        # overload allows it too (it used to be keyword-only here).
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
671
672
673
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. A list must contain one
                entry per prompt.
            prompt_token_ids: Legacy alternative to ``prompts``: a list of
                token IDs, or a list of such lists. Deprecated; pass
                tokenized prompts through ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. A list
                must contain one entry per prompt.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running in embedding mode, or if
                per-request arguments do not match the number of prompts.

        Note:
            Passing ``prompts`` by keyword, or using ``prompt_token_ids``, is
            considered legacy and may be deprecated in the future. You should
            instead pass prompts (including tokenized prompts) positionally
            via the ``prompts`` parameter.
        """
        if not self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is not None:
            # Legacy (v1) call style: build prompt dicts from the separate
            # text / token-id arguments.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
740

741
742
743
744
745
746
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to end profiling."""
        engine = self.llm_engine
        engine.stop_profile()

747
748
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Build one prompt dict per request from the legacy v1 arguments.

        Args:
            prompts: A text prompt, a list of text prompts, or None.
            prompt_token_ids: A list of token IDs, a list of such lists,
                or None.

        Returns:
            A list with one entry per request: a ``TextPrompt`` when
            ``prompts`` is given, otherwise a ``TokensPrompt``.

        Raises:
            ValueError: If both arguments are given with different batch
                sizes, or if neither is given.
        """
        # skip_tokenizer_init is now checked in engine

        # Normalize each legacy argument into a batch of raw contents.
        texts = None
        if prompts is not None:
            texts = [
                item["content"] for item in parse_and_batch_prompt(prompts)
            ]
        token_lists = None
        if prompt_token_ids is not None:
            token_lists = [
                item["content"]
                for item in parse_and_batch_prompt(prompt_token_ids)
            ]

        if (texts is not None and token_lists is not None
                and len(texts) != len(token_lists)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        if texts is not None:
            # Text prompts take precedence: when both arguments are given,
            # the token IDs only participate in the length check above.
            return [TextPrompt(prompt=text) for text in texts]

        if token_lists is not None:
            return [TokensPrompt(prompt_token_ids=ids) for ids in token_lists]

        raise ValueError("Either prompts or prompt_token_ids must be "
                         "provided.")
790
791
792

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate per-request arguments and enqueue one request per prompt.

        Args:
            prompts: A single prompt (``str`` or ``dict``) or a sequence of
                prompts.
            params: Sampling/pooling parameters, either one object shared by
                all prompts or a sequence with one entry per prompt.
            lora_request: One LoRA request shared by all prompts, a sequence
                with one entry per prompt, or None.
            prompt_adapter_request: Prompt adapter request applied to every
                prompt, or None.
            guided_options: Deprecated guided-decoding options, folded into
                each ``SamplingParams`` when given.
            priority: Optional per-prompt priorities; when empty or None,
                every request gets priority 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or ``priority`` is a
                sequence whose length differs from the number of prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)
        # Validate with ``Sequence`` (not just ``list``) so that tuples are
        # checked the same way they are indexed when requests are added.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Check up front so a short list cannot fail with IndexError after
        # some requests have already been handed to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        params_seq = params if isinstance(params, Sequence) else (params, )
        # Prepare each distinct SamplingParams object exactly once: the same
        # instance may be repeated (e.g. ``[sp] * n``), and
        # ``_add_guided_params`` rejects params whose guided_decoding is
        # already set.
        prepared_ids = set()
        for sp in params_seq:
            if not isinstance(sp, SamplingParams) or id(sp) in prepared_ids:
                continue
            prepared_ids.add(id(sp))
            self._add_guided_params(sp, guided_options)

            # We only care about the final output
            sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
839

840
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit one prompt to the engine under a freshly issued request ID.

        The ID is the next value of ``self.request_counter``, converted to a
        string before it is passed to ``llm_engine.add_request``.
        """
        next_id = next(self.request_counter)
        engine = self.llm_engine
        engine.add_request(
            str(next_id),
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
857

858
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None
    ) -> SamplingParams:
        """Fold deprecated ``guided_options`` into ``params.guided_decoding``.

        ``params`` is modified in place and also returned.

        Args:
            params: The sampling parameters to update.
            guided_options: Legacy guided-decoding request. If None,
                ``params`` is returned unchanged.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``guided_options`` is given but
                ``params.guided_decoding`` is already set.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            # Trailing space needed: the two literals are concatenated.
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)
        return params

879
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to display a tqdm progress bar with estimated
                input/output token throughput.

        Returns:
            Every finished output, sorted by integer request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # tqdm's elapsed time can still be 0.0 (coarse clock
                        # or an immediately finished request); skip the
                        # speed estimate rather than divide by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even if an engine step raises.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
923
924
925
926
927
928

    def _is_encoder_decoder_model(self):
        return self.llm_engine.is_encoder_decoder_model()

    def _is_embedding_model(self):
        return self.llm_engine.is_embedding_model()