llm.py 40.3 KB
Newer Older
1
import itertools
2
import warnings
3
from contextlib import contextmanager
Joe Runde's avatar
Joe Runde committed
4
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
5
                    Union, cast, overload)
6

7
from tqdm import tqdm
8

9
from vllm import envs
10
11
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
12
from vllm.engine.arg_utils import EngineArgs, TaskOption
Joe Runde's avatar
Joe Runde committed
13
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
14
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
15
16
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
nunjunj's avatar
nunjunj committed
17
                                         parse_chat_messages)
18
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
19
from vllm.inputs.parse import parse_and_batch_prompt
20
from vllm.logger import init_logger
21
from vllm.lora.request import LoRARequest
22
23
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
24
25
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
26
from vllm.prompt_adapter.request import PromptAdapterRequest
27
28
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
29
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
30
31
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
32
from vllm.usage.usage_lib import UsageContext
33
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
34

35
36
# Module-level logger for this file, created via vLLM's ``init_logger`` helper.
logger = init_logger(__name__)

37
38

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
39
40
41
42
43
44
45
46
47
48
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If unset (None, the default), it defaults to False.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See ParallelConfig
        disable_async_output_proc: Forwarded unchanged to
            :class:`~vllm.EngineArgs`; see its documentation.
        mm_processor_kwargs: Multimodal processor kwarg overrides, forwarded
            unchanged to :class:`~vllm.EngineArgs`.
        task: The task to initialize the model for (default ``"auto"``),
            forwarded to :class:`~vllm.EngineArgs`. :meth:`generate` raises
            ``ValueError`` unless the model ends up initialized for the
            ``"generate"`` task.
        pooling_type: Pooling option forwarded unchanged to
            :class:`~vllm.EngineArgs`.
        pooling_norm: Pooling option forwarded unchanged to
            :class:`~vllm.EngineArgs`.
        pooling_softmax: Pooling option forwarded unchanged to
            :class:`~vllm.EngineArgs`.
        pooling_step_tag_id: Pooling option forwarded unchanged to
            :class:`~vllm.EngineArgs`.
        pooling_returned_token_ids: Pooling option forwarded unchanged to
            :class:`~vllm.EngineArgs`.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily mark the legacy generate/encode API as deprecated.

        While the context is active, ``cls.DEPRECATE_LEGACY`` is ``True``.
        That is the flag consulted by the ``deprecate_kwargs`` decorator on
        :meth:`generate` for the ``prompt_token_ids`` keyword.

        On exit the flag is restored to the value it had on entry. This also
        happens when the body raises, and it keeps nested uses from clobbering
        an outer setting.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Without ``finally`` an exception inside the ``with`` body would
            # leave the flag permanently enabled.
            cls.DEPRECATE_LEGACY = previous

    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        pooling_type: Optional[str] = None,
        pooling_norm: Optional[bool] = None,
        pooling_softmax: Optional[bool] = None,
        pooling_step_tag_id: Optional[int] = None,
        pooling_returned_token_ids: Optional[List[int]] = None,
        **kwargs,
    ) -> None:
        """Build :class:`~vllm.EngineArgs` and start the underlying engine.

        Every named parameter, plus any extra ``kwargs``, is forwarded to
        :class:`~vllm.EngineArgs`. Unless the caller supplies
        ``disable_log_stats`` explicitly, it is forced to ``True``.

        Note: if ``enforce_eager`` is unset (``None``), it defaults to False.
        """
        # An explicit caller-provided value for this key always wins.
        kwargs.setdefault("disable_log_stats", True)

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            mm_processor_kwargs=mm_processor_kwargs,
            pooling_type=pooling_type,
            pooling_norm=pooling_norm,
            pooling_softmax=pooling_softmax,
            pooling_step_tag_id=pooling_step_tag_id,
            pooling_returned_token_ids=pooling_returned_token_ids,
            **kwargs,
        )

        # The engine implementation is picked at runtime rather than at import
        # time, which sidesteps import-order problems.
        engine_cls = self.get_engine_class()
        self.engine_class = engine_cls
        self.llm_engine = engine_cls.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

        self.request_counter = Counter()

    @staticmethod
    def get_engine_class() -> Type[LLMEngine]:
        """Return the engine class to use, chosen by ``envs.VLLM_USE_V1``.

        If the flag is off, this is the regular :class:`LLMEngine`. If it is
        on, the V1 engine is imported lazily and returned instead.
        """
        if not envs.VLLM_USE_V1:
            return LLMEngine

        # Lazy import: the v1 package isn't distributed
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        return V1LLMEngine  # type: ignore

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's ``TokenizerGroup``."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's ``TokenizerGroup``.

        Tokenizers whose class name starts with ``"Cached"`` are stored as-is.
        Any other tokenizer is wrapped with :func:`get_cached_tokenizer` first.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # Cached tokenizer classes are created dynamically, so the class name
        # is the only signal available. A user-defined tokenizer whose class
        # name happens to start with "Cached" will be misjudged.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))

    # Typing-only overloads for ``generate``. The legacy variants accept
    # ``prompt_token_ids``. The last one is the current API: it takes the
    # prompts positionally and exposes every optional argument that the
    # implementation accepts.
    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        # The implementation accepts sampling_params positionally as well;
        # allowing it here is a strict superset of the former keyword-only
        # signature.
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        # These are accepted by the implementation; without them here,
        # type-checked callers could not pass them.
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: (Legacy) Token IDs of the prompts. When given,
                they are combined with ``prompts`` into the new input format.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options for the requests,
                if any. A plain dict is converted to a
                :class:`GuidedDecodingRequest` and may contain at most one
                entry.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the ``"generate"``
                task, or if ``guided_options_request`` is a dict with more
                than one entry.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass the prompts positionally via ``prompts``; token IDs
            can be given there as :class:`~vllm.inputs.TokensPrompt`.
        """
        task = self.llm_engine.model_config.task
        if task != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "generate" in supported_tasks:
                messages.append(
                    "Your model supports the 'generate' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge string prompts and token IDs into the
            # current prompt representation.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Each step does the following:

        - The live beams of every prompt are sent to :meth:`generate` as one
          batch. That batch decodes a single token and requests
          ``2 * beam_width`` logprobs per beam.
        - Every beam is extended with each returned candidate token.
        - A candidate ending in the tokenizer's EOS token is moved to the
          completed set, unless ``params.ignore_eos`` is set.
        - The ``beam_width`` best remaining candidates per prompt become the
          new beams.

        Decoding stops after ``params.max_tokens`` steps, or earlier when no
        live beams remain.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per prompt, in input order. Each
            output holds up to ``beam_width`` sequences, ordered best-first by
            :func:`get_beam_search_score`. Each sequence has ``text`` decoded
            from its tokens.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            prompt_tokens = prompt if isinstance(
                prompt, list) else tokenizer.encode(prompt)
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten every instance's beams into one batch. chain() is linear;
            # sum(lists, []) would copy the growing list for each instance.
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))

            if len(all_beams) == 0:
                break

            # instance_start_and_end[k] is the half-open slice [start, end) of
            # all_beams (and of the generate() output) owned by instances[k].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after the last step compete with the finished
            # ones for the final top-k.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each chat conversation is converted into a prompt by applying the
        tokenizer's chat template. The prompt is either text or a list of
        token IDs, depending on what the template returns. The prompts are
        then passed to the :meth:`generate` method to generate the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.
                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            add_generation_prompt: If True, adds a generation template
                to each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be `True`
                if `add_generation_prompt` is also `True`.
            tools: Tool definitions passed to the chat template, if any.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request, attached to every prompt built from ``messages``.
                Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is List[...]
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]

        # Nothing in the loop below replaces the tokenizer or the model
        # config, so fetch them once rather than once per conversation.
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs, model_config, tokenizer)

            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: a single text prompt, optionally with its token IDs."""

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: a batch of text prompts, optionally with token IDs."""

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: a single token-ID prompt, optionally with its text."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: a batch of token-ID prompts, optionally with texts."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: token IDs passed positionally, with no text prompts."""

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Preferred form: prompt(s) positional-only, options keyword-only."""

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Compute embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. If a list is given, it
                must contain one entry per prompt.
            prompt_token_ids: DEPRECATED. Legacy way of passing pre-tokenized
                prompts; pass them through ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for encoding, if any. If a list
                is given, it must contain one entry per prompt.
            prompt_adapter_request: Prompt Adapter request to use for
                encoding, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the
                ``"embedding"`` task, or if per-prompt lists do not match the
                number of prompts.

        Note:
            Passing ``prompt_token_ids`` (optionally together with plain-string
            ``prompts``) is considered legacy and may be deprecated in the
            future. You should instead pass all prompts via the ``prompts``
            parameter.
        """
        task = self.llm_engine.model_config.task
        if task != "embedding":
            messages = ["LLM.encode() is only supported for embedding models."]

            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "embedding" in supported_tasks:
                messages.append(
                    "Your model supports the 'embedding' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task embedding`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: combine text prompts and token ids into prompt dicts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  EmbeddingRequestOutput)

    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        engine = self.llm_engine
        engine.stop_profile()

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Turn legacy ``prompts``/``prompt_token_ids`` into prompt dicts.

        Both arguments are first normalized into batches. When text prompts
        are present they win and one ``TextPrompt`` is built per text (the
        token ids are then only used for the length check). Otherwise one
        ``TokensPrompt`` is built per token-id list.

        Raises:
            ValueError: If both batches are given with different lengths, or
                if neither argument is given.
        """
        # skip_tokenizer_init is now checked in engine
        texts = (None if prompts is None else
                 [entry["content"] for entry in parse_and_batch_prompt(prompts)])
        token_batches = (None if prompt_token_ids is None else [
            entry["content"]
            for entry in parse_and_batch_prompt(prompt_token_ids)
        ])

        both_given = texts is not None and token_batches is not None
        if both_given and len(texts) != len(token_batches):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        if texts is None and token_batches is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        if texts is not None:
            return [TextPrompt(prompt=text) for text in texts]
        return [TokensPrompt(prompt_token_ids=ids) for ids in token_batches]

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate request arguments and enqueue one engine request per prompt.

        Args:
            prompts: A single prompt or a sequence of prompts. A bare string or
                a single prompt dict is treated as a batch of one.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt.
            lora_request: One LoRA request (or None) shared by every prompt,
                or a sequence with exactly one entry per prompt.
            prompt_adapter_request: Prompt adapter request applied to every
                prompt, if any.
            guided_options: Deprecated. Copied into each ``SamplingParams``'
                ``guided_decoding`` field via ``_add_guided_params``.
            priority: Optional per-prompt priorities. When non-empty, it must
                have one entry per prompt; otherwise every request gets 0.

        Raises:
            ValueError: If a per-prompt sequence (``params``,
                ``lora_request``, ``priority``) does not have exactly one entry
                per prompt, or if ``_add_guided_params`` rejects the options.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Validate with the same ``Sequence`` test used for indexing below, so
        # tuples are length-checked just like lists.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # Check up front: a short list would otherwise raise IndexError after
        # some requests had already been handed to the engine.
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )

    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Register a single prompt with the engine under a fresh request ID.

        The ID is drawn from ``self.request_counter`` and converted to a
        string. ``_run_engine`` later sorts outputs by ``int(request_id)``,
        so the ID must be a numeric string.
        """
        next_id = next(self.request_counter)
        self.llm_engine.add_request(
            str(next_id),
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy deprecated ``guided_options`` into ``params.guided_decoding``.

        ``params`` is modified in place and also returned.

        Args:
            params: Sampling parameters to update.
            guided_options: Deprecated guided-decoding request. When None,
                ``params`` is returned unchanged.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``params.guided_decoding`` is already set, because
                the two guided-decoding configurations would conflict.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until every queued request has finished.

        Args:
            use_tqdm: Whether to display a progress bar. When enabled, the bar
                also shows estimated input/output token throughput, which is
                only updated for ``RequestOutput`` results.

        Returns:
            All finished outputs, sorted by numeric request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if not output.finished:
                    continue
                outputs.append(output)
                if not use_tqdm:
                    continue
                if isinstance(output, RequestOutput):
                    # Calculate tokens only for RequestOutput
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids)
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs)
                    elapsed = pbar.format_dict["elapsed"]
                    # elapsed can still be 0.0 if a request finishes right
                    # after the bar is created; skip the rate update rather
                    # than dividing by zero.
                    if elapsed > 0:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                pbar.update(1)

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))

    def _is_encoder_decoder_model(self):
        """Return whatever the engine reports for is_encoder_decoder_model()."""
        engine = self.llm_engine
        return engine.is_encoder_decoder_model()