llm.py 36.2 KB
Newer Older
1
import itertools
2
from contextlib import contextmanager
3
4
5
from dataclasses import dataclass
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
                    Union, cast, overload)
6

7
from tqdm import tqdm
8

Woosuk Kwon's avatar
Woosuk Kwon committed
9
10
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
11
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
12
13
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
nunjunj's avatar
nunjunj committed
14
                                         parse_chat_messages)
15
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
16
from vllm.inputs.parse import parse_and_batch_prompt
17
from vllm.logger import init_logger
18
from vllm.lora.request import LoRARequest
19
20
21
from vllm.model_executor.guided_decoding import (
    GuidedDecodingRequest, get_local_guided_decoding_logits_processor)
from vllm.model_executor.guided_decoding.guided_fields import LLMGuidedOptions
22
23
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
24
from vllm.prompt_adapter.request import PromptAdapterRequest
25
from vllm.sampling_params import RequestOutputKind, SamplingParams
26
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
27
28
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
29
from vllm.usage.usage_lib import UsageContext
30
from vllm.utils import Counter, deprecate_kwargs, is_list_of
31

32
33
logger = init_logger(__name__)

34

35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
@dataclass
class BeamSearchSequence:
    """A sequence for beam search.

    It keeps track of the tokens and the log probability of the sequence.
    The text field is optional and will only be filled when the sequence is
    about to be returned to the user.
    """
    # Token ids of the sequence; the prompt tokens are included.
    tokens: List[int]
    # Log probability accumulated over the sequence so far.
    cum_logprob: float = 0.0
    # Decoded text; stays None until the sequence is about to be returned.
    text: Optional[str] = None


@dataclass
class BeamSearchOutput:
    """The output of beam search.

    It contains the list of the best beam search sequences.
    The list holds at most ``beam_width`` sequences; it can be shorter when
    fewer candidates were produced.
    """
    # Best sequences found for one prompt.
    sequences: List[BeamSearchSequence]


class BeamSearchInstance:
    """Beam-search bookkeeping for a single prompt.

    ``beams`` starts out as one sequence holding just the prompt tokens and
    ``completed`` starts out empty.
    """

    def __init__(self, prompt_tokens: List[int]):
        seed_beam = BeamSearchSequence(tokens=prompt_tokens)
        self.beams: List[BeamSearchSequence] = [seed_beam]
        self.completed: List[BeamSearchSequence] = []


66
class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
67
68
69
70
71
72
73
74
75
76
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
77
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
78
79
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
80
81
82
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
83
84
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
Woosuk Kwon's avatar
Woosuk Kwon committed
85
86
87
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
88
89
90
91
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
92
        quantization: The method used to quantize the model weights. Currently,
93
            we support "awq", "gptq", and "fp8" (experimental).
94
95
96
97
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
98
99
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
100
101
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
102
103
104
105
106
107
108
109
110
111
112
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
113
114
115
116
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
117
118
119
120
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
121
122
123
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
124
            When a sequence has context length larger than this, we fall back
125
126
127
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
128
        disable_custom_all_reduce: See ParallelConfig
129
130
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)
nunjunj's avatar
nunjunj committed
131

132
133
134
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
Woosuk Kwon's avatar
Woosuk Kwon committed
135
    """
136

137
138
139
140
141
142
143
144
145
146
147
148
    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

149
150
151
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        **kwargs,
    ) -> None:
        """Assemble the engine arguments and create the underlying engine.

        Note: ``enforce_eager`` may be left unset (``None``), in which case
        it defaults to False.

        Raises:
            TypeError: If one of the removed vision-related keyword
                arguments is passed.
        """
        # `disable_log_stats` defaults to True unless the caller passed it.
        kwargs.setdefault("disable_log_stats", True)

        # These keyword arguments are no longer accepted.
        removed_vision_keys = (
            "image_token_id",
            "image_feature_size",
            "image_input_shape",
            "image_input_type",
        )
        if not kwargs.keys().isdisjoint(removed_vision_keys):
            raise TypeError(
                "There is no need to pass vision-related arguments anymore.")

        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            mm_processor_kwargs=mm_processor_kwargs,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        # Per-instance counter of requests.
        self.request_counter = Counter()

218
219
220
221
222
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's tokenizer group."""
        tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return tokenizer_group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer whose class name starts with ``"Cached"`` is stored as
        is; any other tokenizer is first passed through
        ``get_cached_tokenizer``.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # The cached tokenizer class is dynamic, so the class name is the
        # only thing we can compare. A user-defined tokenizer whose class
        # name begins with 'Cached' will therefore be misjudged.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))
231

232
233
234
235
236
237
238
239
    # Legacy call form: one text prompt given as a string, optionally
    # accompanied by its token ids.
    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    # Legacy call form: a list of text prompts, optionally accompanied by a
    # list of token-id lists.
    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    # Legacy call form: one prompt given as token ids through the
    # keyword-only `prompt_token_ids`; the prompt text is optional.
    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    # Legacy call form: several prompts given as token-id lists through the
    # keyword-only `prompt_token_ids`; the prompt texts are optional.
    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    # Legacy call form: token ids (one prompt or a batch) passed as the third
    # positional argument, with `prompts` and `sampling_params` given as None.
    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    # Current call form: `inputs` is positional-only and every other option
    # is keyword-only.
    @overload
    def generate(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
306
307
308
309
310
311
    @deprecate_kwargs(
        "prompts",
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'inputs' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The inputs to generate completions for; either a single
                input or a sequence of inputs.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing token ids. When given, it
                is combined with ``prompts`` by ``_convert_v1_inputs``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                is converted to a ``GuidedDecodingRequest`` and must not
                contain more than one entry.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model runs in embedding mode, or if
                ``guided_options_request`` is a dict with more than one
                entry.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        model_config = self.llm_engine.model_config
        if model_config.embedding_mode:
            raise ValueError(
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).")

        inputs: Union[PromptInputs, Sequence[PromptInputs]]
        if prompt_token_ids is None:
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
        else:
            # Legacy path: merge the text prompts with their token ids.
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )

        if isinstance(guided_options_request, dict):
            # A dict is expanded into a request object; it may name at most
            # one guided decoding option.
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        # Fall back to the default sampling params when none were given.
        params = SamplingParams() if sampling_params is None \
            else sampling_params

        self._validate_and_add_requests(
            inputs=inputs,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request)

        engine_outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(engine_outputs, RequestOutput)
386

387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
    def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        beam_width: int,
        max_tokens: int,
        ignore_eos: bool = False,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            beam_width: The number of beams to keep at each step.
            max_tokens: The max number of tokens to generate for each prompt.
            ignore_eos: If False, a candidate that ends with the tokenizer's
                EOS token is moved to the completed set and not extended
                further. If True, EOS is treated like any other token.

        Returns:
            One ``BeamSearchOutput`` per prompt, in prompt order. Each holds
            at most ``beam_width`` sequences, sorted by descending cumulative
            log probability, with ``text`` filled in.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """

        tokenizer = self.get_tokenizer()
        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=0.0)

        # String prompts are tokenized; token-id prompts are used as given.
        instances: List[BeamSearchInstance] = [
            BeamSearchInstance(prompt if isinstance(prompt, list) else
                               tokenizer.encode(prompt)) for prompt in prompts
        ]

        for _ in range(max_tokens):
            # Flatten the live beams of all instances into one batch in
            # linear time (``sum(lists, [])`` re-copies the accumulated list
            # on every addition).
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(instance.beams
                                              for instance in instances))
            if not all_beams:
                break

            # [start, end) offsets of each instance's beams inside the batch.
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for current_beam, result in zip(all_beams[start:end],
                                                output[start:end]):
                    step_logprobs = result.outputs[0].logprobs
                    if step_logprobs is None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        continue

                    # Only one token was generated, so only position 0 exists.
                    for token_id, logprob_obj in step_logprobs[0].items():
                        new_beam = BeamSearchSequence(
                            tokens=current_beam.tokens + [token_id],
                            cum_logprob=current_beam.cum_logprob +
                            logprob_obj.logprob)

                        if token_id == tokenizer.eos_token_id and \
                            not ignore_eos:
                            instance.completed.append(new_beam)
                        else:
                            instance_new_beams.append(new_beam)

                # Keep only the `beam_width` most likely candidates.
                sorted_beams = sorted(instance_new_beams,
                                      key=lambda x: x.cum_logprob,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive at the end compete with the completed ones.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=lambda x: x.cum_logprob,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
486
487
488
489
490
491
492
493
    def chat(
        self,
        messages: List[ChatCompletionMessageParam],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        tools: Optional[List[Dict[str, Any]]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The conversation is rendered into a single prompt with the
        tokenizer's chat template, and that prompt is handed to
        :meth:`generate`.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A single conversation represented as a list of messages.
                Each message is a dictionary with 'role' and 'content' keys.
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            add_generation_prompt: Whether a generation prompt is added when
                the chat template is applied; forwarded to the template
                helper.
            tools: Tool definitions forwarded to the template helper, if any.

        Returns:
            The list of ``RequestOutput`` objects returned by
            :meth:`generate` for the rendered prompt.
        """

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        conversation, mm_data = parse_chat_messages(messages, model_config,
                                                    tokenizer)

        # The rendered prompt is either text or a list of token ids.
        rendered: Union[str, List[int]]
        if not isinstance(tokenizer, MistralTokenizer):
            rendered = apply_hf_chat_template(
                tokenizer,
                conversation=conversation,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                tools=tools,
            )
        else:
            # The Mistral helper takes the raw messages, not the parsed
            # conversation.
            rendered = apply_mistral_chat_template(
                tokenizer,
                messages=messages,
                chat_template=chat_template,
                add_generation_prompt=add_generation_prompt,
                tools=tools,
            )

        request: PromptInputs = (TokensPrompt(prompt_token_ids=rendered)
                                 if is_list_of(rendered, int) else
                                 TextPrompt(prompt=rendered))
        if mm_data is not None:
            request["multi_modal_data"] = mm_data

        return self.generate(
            request,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

567
568
569
570
571
572
573
574
    # Legacy call form: one text prompt given as a string, optionally
    # accompanied by its token ids.
    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...
578

579
    # Legacy call form: a list of text prompts, optionally accompanied by a
    # list of token-id lists.
    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    # Legacy call form: one prompt given as token ids through the
    # keyword-only `prompt_token_ids`; the prompt text is optional.
    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    # Legacy call form: several prompts given as token-id lists through the
    # keyword-only `prompt_token_ids`; the prompt texts are optional.
    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    # Legacy call form: token ids (one prompt or a batch) passed as the third
    # positional argument, with `prompts` and `pooling_params` given as None.
    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    # Current call form: `inputs` is positional-only and every other option
    # is keyword-only.
    @overload
    def encode(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
641
642
643
644
645
646
    @deprecate_kwargs(
        "prompts",
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'inputs' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The inputs to the LLM (called ``inputs`` in the
                non-legacy overloads). You may pass a sequence of inputs for
                batch inference. See :class:`~vllm.inputs.PromptInputs`
                for more details about the format of each input.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            prompt_token_ids: Legacy way of passing pre-tokenized prompts.
                When given, ``prompts`` and ``prompt_token_ids`` are
                converted together by ``_convert_v1_inputs``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the loaded model is not in embedding mode.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        model_config = self.llm_engine.model_config
        if not model_config.embedding_mode:
            raise ValueError(
                "LLM.encode() is only supported for embedding models (XModel)."
            )

        if prompt_token_ids is None:
            # New-style call: ``prompts`` already holds the prompt inputs.
            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
        else:
            # Legacy call: fold text prompts and token ids into prompt inputs.
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )

        # Fall back to the default pooling params when none are given.
        effective_params = (PoolingParams()
                            if pooling_params is None else pooling_params)

        self._validate_and_add_requests(
            inputs=inputs,
            params=effective_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        return LLMEngine.validate_outputs(
            self._run_engine(use_tqdm=use_tqdm), EmbeddingRequestOutput)
710

711
712
713
714
715
716
    def start_profile(self) -> None:
        """Forward a start-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Forward a stop-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

717
718
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Convert legacy ``prompts`` / ``prompt_token_ids`` arguments.

        Each provided argument is passed through ``parse_and_batch_prompt``
        and the ``"content"`` of every parsed entry is kept. When text
        prompts are given, one ``TextPrompt`` is built per text (the token
        ids are then only used for the length check); otherwise one
        ``TokensPrompt`` is built per token-id list.

        Args:
            prompts: A single text prompt, a list of them, or None.
            prompt_token_ids: One token-id list, a list of them, or None.

        Returns:
            A list with one prompt input per request.

        Raises:
            ValueError: If both arguments are None, or if both are given
                but describe a different number of requests.
        """
        # skip_tokenizer_init is now checked in engine

        texts = None
        if prompts is not None:
            texts = [
                parsed["content"] for parsed in parse_and_batch_prompt(prompts)
            ]

        token_lists = None
        if prompt_token_ids is not None:
            token_lists = [
                parsed["content"]
                for parsed in parse_and_batch_prompt(prompt_token_ids)
            ]

        if texts is None and token_lists is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        if (texts is not None and token_lists is not None
                and len(texts) != len(token_lists)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        # Text prompts take precedence whenever they are present.
        if texts is not None:
            return [TextPrompt(prompt=text) for text in texts]

        # Unreachable failure: the None/None case was rejected above.
        assert token_lists is not None
        return [TokensPrompt(prompt_token_ids=ids) for ids in token_lists]
760
761
762

    def _validate_and_add_requests(
        self,
        inputs: Union[PromptInputs, Sequence[PromptInputs]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
    ) -> None:
        """Validate per-request arguments and queue every input on the engine.

        ``params`` and ``lora_request`` may each be a single object shared by
        all requests or a sequence holding one entry per request. The same
        ``Sequence`` test is used for validation, for preparing sampling
        params and for per-request indexing, so tuples are handled exactly
        like lists.

        Args:
            inputs: One prompt input (a ``str`` or ``dict``) or a sequence of
                prompt inputs.
            params: Sampling or pooling params, shared or one per request.
            lora_request: LoRA request(s), shared or one per request, or None.
            prompt_adapter_request: Prompt adapter request passed along with
                every request, or None.
            guided_options: Optional guided-decoding options applied to every
                ``SamplingParams`` instance.

        Raises:
            ValueError: If a per-request sequence does not have the same
                length as ``inputs``.
        """
        if isinstance(inputs, (str, dict)):
            # Convert a single prompt to a list.
            inputs = [inputs]

        num_requests = len(inputs)

        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if (isinstance(lora_request, Sequence)
                and len(lora_request) != num_requests):
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_processor(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, request_inputs in enumerate(inputs):
            self._add_request(
                request_inputs,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
            )
799

800
    def _add_request(
        self,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        """Queue one request on the engine under a freshly drawn request id.

        Args:
            inputs: The prompt input for this request.
            params: Sampling or pooling params for this request.
            lora_request: Optional LoRA request for this request.
            prompt_adapter_request: Optional prompt adapter request.
        """
        # The id is the stringified next value of the request counter.
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            inputs,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
815

816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
    def _add_guided_processor(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Attach a guided-decoding logits processor to ``params`` if needed.

        When ``guided_options`` is falsy, ``params`` is returned untouched.
        Otherwise a missing ``guided_decoding_backend`` on the options is
        filled in (in place) from the engine's decoding config, a logits
        processor is requested for those options, and, if one is returned,
        it is appended to ``params.logits_processors``.

        Args:
            params: Sampling params to update in place.
            guided_options: Guided-decoding options, or None.

        Returns:
            The same ``params`` object that was passed in.
        """
        if not guided_options:
            return params

        if guided_options.guided_decoding_backend is None:
            # Default to the backend configured on the engine.
            engine_decoding_config = self.llm_engine.get_decoding_config()
            guided_options.guided_decoding_backend = (
                engine_decoding_config.guided_decoding_backend)

        processor = get_local_guided_decoding_logits_processor(
            guided_options.guided_decoding_backend, guided_options,
            self.get_tokenizer())
        if not processor:
            return params

        if params.logits_processors is None:
            params.logits_processors = []
        params.logits_processors.append(processor)
        return params

834
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until every queued request has finished.

        Args:
            use_tqdm: Whether to show a tqdm progress bar with estimated
                input/output token throughput.

        Returns:
            The finished outputs, sorted by their integer request id.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # The reported elapsed time can be 0 (e.g. a disabled
                        # progress bar or a coarse clock); report 0 toks/s
                        # instead of raising ZeroDivisionError.
                        elapsed = pbar.format_dict["elapsed"]
                        in_spd = total_in_toks / elapsed if elapsed else 0.0
                        out_spd = total_out_toks / elapsed if elapsed else 0.0
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even when an engine step raises.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
878
879
880
881
882
883

    def _is_encoder_decoder_model(self):
        return self.llm_engine.is_encoder_decoder_model()

    def _is_embedding_model(self):
        return self.llm_engine.is_embedding_model()