"vllm/lora/model_manager.py" did not exist on "17edd8a807019c8d1e58634aecb1de7984e8d467"
llm.py 37.2 KB
Newer Older
1
import itertools
2
import warnings
3
from contextlib import contextmanager
4
5
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
                    Union, cast, overload)
6

7
from tqdm import tqdm
8

9
10
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
Woosuk Kwon's avatar
Woosuk Kwon committed
11
12
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
13
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
14
15
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
nunjunj's avatar
nunjunj committed
16
                                         parse_chat_messages)
17
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
18
from vllm.inputs.parse import parse_and_batch_prompt
19
from vllm.logger import init_logger
20
from vllm.lora.request import LoRARequest
21
22
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
23
24
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
25
from vllm.prompt_adapter.request import PromptAdapterRequest
26
27
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
28
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
29
30
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
31
from vllm.usage.usage_lib import UsageContext
32
from vllm.utils import Counter, deprecate_kwargs, is_list_of
33

34
35
# Module-level logger, named after this module via init_logger(__name__).
logger = init_logger(__name__)

36
37

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
38
39
40
41
42
43
44
45
46
47
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
48
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
49
50
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
51
52
53
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
54
55
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
Woosuk Kwon's avatar
Woosuk Kwon committed
56
57
58
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
59
60
61
62
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
63
        quantization: The method used to quantize the model weights. Currently,
64
            we support "awq", "gptq", and "fp8" (experimental).
65
66
67
68
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
69
70
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
71
72
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
73
74
75
76
77
78
79
80
81
82
83
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
84
85
86
87
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
88
89
90
91
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
92
93
94
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
95
            When a sequence has context length larger than this, we fall back
96
97
98
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
99
        disable_custom_all_reduce: See ParallelConfig
100
101
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)
nunjunj's avatar
nunjunj committed
102

103
104
105
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
Woosuk Kwon's avatar
Woosuk Kwon committed
106
    """
107

108
109
110
111
112
113
114
115
116
117
118
119
    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

120
121
122
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Packs the explicit arguments plus any extra ``kwargs`` into an
        :class:`EngineArgs` and builds the underlying :class:`LLMEngine`
        from it (tagged with ``UsageContext.LLM_CLASS``).

        Note:
            If ``enforce_eager`` is unset (None), it defaults to False.
            ``disable_log_stats`` defaults to True here unless the caller
            passes it explicitly in ``kwargs``.
        """
        # Offline usage: don't log engine stats unless the caller opts in.
        kwargs.setdefault("disable_log_stats", True)

        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            mm_processor_kwargs=mm_processor_kwargs,
            **kwargs,
        )

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        # Source of request IDs for requests submitted by this instance.
        self.request_counter = Counter()

181
182
183
184
185
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's ``TokenizerGroup``."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Replace the tokenizer stored in the engine's ``TokenizerGroup``.

        The tokenizer is wrapped with :func:`get_cached_tokenizer` unless it
        already looks like a cached tokenizer.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # The cached-tokenizer class is created dynamically, so there is no
        # fixed class to isinstance-check against; we have no choice but to
        # compare the class name. A user-defined tokenizer whose class name
        # starts with "Cached" will be misjudged and left unwrapped.
        if tokenizer.__class__.__name__.startswith("Cached"):
            tokenizer_group.tokenizer = tokenizer
        else:
            tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer)
194

195
196
197
198
199
200
201
202
    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        # Typing stub only (@overload): one text prompt, optionally with its
        # token IDs.
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        # Typing stub only (@overload): a batch of text prompts, optionally
        # with one token-ID list per prompt.
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        # Typing stub only (@overload): keyword-only token IDs for a single
        # prompt; the text prompt is optional.
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        # Typing stub only (@overload): keyword-only token IDs for a batch of
        # prompts; the text prompts are optional.
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        # Typing stub only (@overload): token IDs passed positionally, with
        # explicit None for prompts and sampling_params.
        ...

    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        # Typing stub only (@overload): the current API. Prompts are
        # positional-only; everything else is keyword-only.
        ...

nunjunj's avatar
nunjunj committed
269
270
271
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: (Legacy) Token IDs of the prompt(s). When given,
                they are combined with ``prompts`` (treated as optional text)
                via the legacy input conversion.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. Either
                a ``GuidedDecodingRequest`` or a dict with at most one entry,
                which is converted into a ``GuidedDecodingRequest``.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the loaded model is an embedding model, or if
                ``guided_options_request`` is a dict with more than one entry.

        Note:
            Passing token IDs through ``prompt_token_ids`` is considered
            legacy and may be deprecated in the future. Pass them through
            the ``prompts`` parameter instead, e.g. as a
            :class:`~vllm.inputs.TokensPrompt`.
        """
        if self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).")

        if prompt_token_ids is not None:
            # Legacy call style: merge text prompts and token IDs.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            # Only a single guided-decoding mode may be requested at once.
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)
355

356
357
358
    def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        Returns:
            One ``BeamSearchOutput`` per prompt, in input order. Each holds
            up to ``params.beam_width`` sequences, ranked by
            ``get_beam_search_score`` (highest first), with ``text`` set to
            the decoded tokens.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Fetched before the key function below, which closes over it.
        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            prompt_tokens = prompt if isinstance(
                prompt, list) else tokenizer.encode(prompt)
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten the live beams of every instance into one batch.
            # chain.from_iterable is linear; sum(lists, []) was quadratic.
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # Instance k owns the slice all_beams[pos[k]:pos[k + 1]].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            # An EOS-terminated candidate is finished and
                            # stops expanding (unless EOS is ignored).
                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                # Keep only the best `beam_width` live beams for next step.
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still live after the loop compete with completed ones.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
463
464
    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each chat conversation is converted into a prompt using the
        tokenizer's chat template (or the Mistral template path for a
        ``MistralTokenizer``), and then :meth:`generate` is called to
        generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API. Any multi-modal data parsed from the messages is
        attached to the prompt as ``multi_modal_data``.

        Args:
            messages: A list of conversations or a single conversation.
                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            add_generation_prompt: If True, the chat template is asked to
                append a generation prompt after each conversation (once
                per conversation, not per message).
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be `True`
                if `add_generation_prompt` is also `True`.
            tools: Tool definitions forwarded to the chat template, if any.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = messages
        else:
            # messages is List[...]
            list_of_messages = [messages]

        # Both are the same for every conversation; fetch them once.
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            conversation, mm_data = parse_chat_messages(
                msgs, model_config, tokenizer)

            # The Mistral path renders from the raw messages; the HF path
            # renders from the parsed conversation.
            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            # The template may yield token IDs or text; wrap accordingly.
            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

566
567
568
569
570
571
572
573
    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Typing stub only (@overload): one text prompt, optionally with its
        # token IDs.
        ...
577

578
    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Typing stub only (@overload): a batch of text prompts, optionally
        # with one token-ID list per prompt.
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Typing stub only (@overload): keyword-only token IDs for a single
        # prompt; the text prompt is optional.
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Typing stub only (@overload): keyword-only token IDs for a batch of
        # prompts; the text prompts are optional.
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Typing stub only (@overload): token IDs passed positionally, with
        # explicit None for prompts and pooling_params.
        ...

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        # Typing stub only (@overload): the current API. Prompts are
        # positional-only; everything else is keyword-only.
        ...

nunjunj's avatar
nunjunj committed
640
641
642
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the embeddings for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. A list must contain one
                entry per prompt.
            prompt_token_ids: (Legacy) Token ids of a single prompt, or a list
                of token-id lists for a batch. When given, ``prompts`` is
                treated as the legacy ``str`` / ``List[str]`` form and the two
                are converted together.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for the requests, if any. A
                list must contain one entry per prompt.
            prompt_adapter_request: Prompt Adapter request to use for
                the requests, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the model is not an embedding model, if neither
                ``prompts`` nor ``prompt_token_ids`` is provided, or if a
                per-request argument does not match the number of prompts.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and may be
            deprecated in the future. You should instead pass token ids as
            :class:`~vllm.inputs.TokensPrompt` objects via the ``prompts``
            parameter.
        """
        if not self.llm_engine.model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        elif prompts is None:
            # Fail with a clear message instead of a TypeError from len(None)
            # further down the request pipeline.
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)

    def start_profile(self) -> None:
        """Forward to ``self.llm_engine.start_profile()``."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Forward to ``self.llm_engine.stop_profile()``."""
        engine = self.llm_engine
        engine.stop_profile()

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Convert legacy ``prompts`` / ``prompt_token_ids`` into prompt dicts.

        Args:
            prompts: A single text prompt, a list of text prompts, or None.
            prompt_token_ids: Token ids of one prompt, a list of token-id
                lists, or None.

        Returns:
            A list of prompts, one per request. When ``prompts`` is given,
            each entry is a ``TextPrompt`` (any ``prompt_token_ids`` are only
            length-checked, not forwarded). Otherwise each entry is a
            ``TokensPrompt``.

        Raises:
            ValueError: If both arguments are given with different lengths,
                or if neither is given.
        """
        # skip_tokenizer_init is now checked in engine

        # Normalize single-or-batched inputs into flat per-request lists.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        parsed_prompts: List[PromptType]
        if prompts is not None:
            # Text prompts take precedence over token ids.
            parsed_prompts = [TextPrompt(prompt=text) for text in prompts]
        elif prompt_token_ids is not None:
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=ids) for ids in prompt_token_ids
            ]
        else:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        return parsed_prompts

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate per-request arguments and submit one request per prompt.

        Args:
            prompts: A single prompt (``str`` or prompt dict) or a sequence
                of prompts.
            params: Sampling or pooling parameters: one object shared by all
                prompts, or a sequence with one entry per prompt.
                ``SamplingParams`` objects are mutated in place: deprecated
                ``guided_options`` are folded in and ``output_kind`` is set
                to ``RequestOutputKind.FINAL_ONLY``.
            lora_request: One LoRA request shared by all prompts, a sequence
                with one entry per prompt, or None.
            prompt_adapter_request: Prompt adapter request shared by all
                prompts, or None.
            guided_options: Deprecated guided-decoding options; a
                ``DeprecationWarning`` is emitted when given.
            priority: Optional per-prompt priorities. When None or empty,
                every request gets priority 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or a non-empty
                ``priority`` is a sequence whose length differs from the
                number of prompts, or if ``guided_options`` conflicts with
                ``params.guided_decoding``.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Check against Sequence (not just list) so tuples are validated the
        # same way the dispatch loop below indexes them.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # An empty priority list keeps its historical meaning of "all 0".
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )

    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the engine under a fresh request id.

        Args:
            prompt: The prompt to submit.
            params: Sampling or pooling parameters for this request.
            lora_request: LoRA request for this prompt, if any.
            prompt_adapter_request: Prompt adapter request, if any.
            priority: Priority forwarded to ``llm_engine.add_request``.
        """
        # The id is the next value of self.request_counter rendered as a
        # string; _run_engine sorts finished outputs by int(request_id), so
        # it must stay integer-valued.
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Fold deprecated ``guided_options`` into ``params.guided_decoding``.

        Args:
            params: Sampling parameters; mutated in place when
                ``guided_options`` is given.
            guided_options: Deprecated guided-decoding request, or None.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``params.guided_decoding`` is already set while
                ``guided_options`` is also given.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until it has no unfinished requests left.

        Args:
            use_tqdm: Whether to display a tqdm progress bar. For
                ``RequestOutput`` results, the bar's postfix also shows the
                estimated input/output token throughput.

        Returns:
            All finished outputs, sorted by their integer-valued request id.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if not use_tqdm:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        elapsed = pbar.format_dict["elapsed"]
                        # On coarse clocks a request can finish with zero
                        # elapsed time; skip the speed update rather than
                        # dividing by zero.
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even if the engine raises mid-run.
            if use_tqdm:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))

    def _is_encoder_decoder_model(self):
        """Return ``self.llm_engine.is_encoder_decoder_model()``."""
        engine = self.llm_engine
        return engine.is_encoder_decoder_model()

    def _is_embedding_model(self):
        """Return ``self.llm_engine.is_embedding_model()``."""
        engine = self.llm_engine
        return engine.is_embedding_model()