llm.py 39.9 KB
Newer Older
1
import itertools
2
import warnings
3
from contextlib import contextmanager
4
5
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple,
                    Union, cast, overload)
6

7
from tqdm import tqdm
8

9
from vllm import envs
10
11
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
12
from vllm.engine.arg_utils import EngineArgs, TaskOption
nunjunj's avatar
nunjunj committed
13
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
14
15
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
nunjunj's avatar
nunjunj committed
16
                                         parse_chat_messages)
17
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
18
from vllm.inputs.parse import parse_and_batch_prompt
19
from vllm.logger import init_logger
20
from vllm.lora.request import LoRARequest
21
22
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
23
24
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
25
from vllm.prompt_adapter.request import PromptAdapterRequest
26
27
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
28
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
29
30
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
31
from vllm.usage.usage_lib import UsageContext
32
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
33

34
35
36
37
38
# Pick the engine implementation once, at import time, from VLLM_USE_V1.
if not envs.VLLM_USE_V1:
    from vllm.engine.llm_engine import LLMEngine  # type: ignore
else:
    from vllm.v1.engine.llm_engine import LLMEngine  # type: ignore

logger = init_logger(__name__)

41
42

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
43
44
45
46
47
48
49
50
51
52
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If True, skip initialization of the tokenizer and
            detokenizer. Inputs are then expected to provide valid
            ``prompt_token_ids`` and ``None`` for the prompt text.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause
            out-of-memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, values that are too small may cause out-of-memory (OOM)
            errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
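        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If None (the default), it is treated as False.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See ParallelConfig.
        disable_async_output_proc: Forwarded unchanged to
            :class:`~vllm.EngineArgs`; see its documentation.
        mm_processor_kwargs: Multimodal processor kwarg overrides, forwarded
            to :class:`~vllm.EngineArgs`.
        task: The task to initialize the model for (``"auto"`` by default).
            :meth:`generate` requires the model to be initialized for the
            ``"generate"`` task.
        pooling_type: Pooling option forwarded to :class:`~vllm.EngineArgs`.
        pooling_norm: Pooling option forwarded to :class:`~vllm.EngineArgs`.
        pooling_softmax: Pooling option forwarded to
            :class:`~vllm.EngineArgs`.
        pooling_step_tag_id: Pooling option forwarded to
            :class:`~vllm.EngineArgs`.
        pooling_returned_token_ids: Pooling option forwarded to
            :class:`~vllm.EngineArgs`.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)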

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """
112

113
114
115
    DEPRECATE_LEGACY: ClassVar[bool] = False
    """Whether the legacy ``generate``/``encode`` calling convention (e.g.
    passing ``prompt_token_ids``) is treated as deprecated."""

    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """Whether passing :meth:`LLM.__init__` arguments other than ``model``
    positionally is treated as deprecated."""

122
123
124
125
126
127
128
129
130
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily mark the legacy API as deprecated.

        While the context is active, :attr:`DEPRECATE_LEGACY` is ``True``, so
        decorators gated on it (e.g. the ``prompt_token_ids`` deprecation on
        :meth:`generate`) treat legacy-style calls as deprecated. On exit the
        flag is restored to the value it had on entry, even if the body of
        the ``with`` statement raises.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore in ``finally`` so an exception raised inside the
            # ``with`` block cannot leave the class-wide flag stuck at True;
            # restoring the previous value also keeps nested use correct.
            cls.DEPRECATE_LEGACY = previous

131
132
133
134
135
136
137
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        pooling_type: Optional[str] = None,
        pooling_norm: Optional[bool] = None,
        pooling_softmax: Optional[bool] = None,
        pooling_step_tag_id: Optional[int] = None,
        pooling_returned_token_ids: Optional[List[int]] = None,
        **kwargs,
    ) -> None:
        """Create the underlying engine used for offline inference.

        Every named argument, plus any extra ``kwargs``, is forwarded
        unchanged to :class:`~vllm.EngineArgs`, from which the engine is
        built. ``disable_log_stats`` defaults to ``True`` unless the caller
        passes it explicitly.

        Note: if ``enforce_eager`` is left unset (``None``) it defaults to
        ``False``.
        """
        # Stats logging is off by default here; an explicit caller value wins.
        kwargs.setdefault("disable_log_stats", True)

        engine_args = EngineArgs(
            # Model / tokenizer selection.
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            # Parallelism, precision and randomness.
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            seed=seed,
            # Memory budget.
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            # Execution mode / CUDA graph capture.
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            mm_processor_kwargs=mm_processor_kwargs,
            # Pooling options.
            pooling_type=pooling_type,
            pooling_norm=pooling_norm,
            pooling_softmax=pooling_softmax,
            pooling_step_tag_id=pooling_step_tag_id,
            pooling_returned_token_ids=pooling_returned_token_ids,
            **kwargs,
        )

        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        self.request_counter = Counter()

212
213
214
215
216
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer stored on the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer that does not already look cached is wrapped with
        :func:`get_cached_tokenizer` before it is stored.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # The cached tokenizer class is created dynamically, so there is no
        # concrete type to isinstance() against; compare the class name
        # instead. A user-defined tokenizer whose class name starts with
        # 'Cached' will be misjudged and stored without wrapping.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer
                           if looks_cached else get_cached_tokenizer(tokenizer))
225

226
227
228
229
230
231
232
233
    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        """Legacy form: one text prompt, optionally with its token ids."""

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        """Legacy form: several text prompts, optionally with token ids."""

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        """Legacy form: one token-id prompt given by keyword."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        """Legacy form: several token-id prompts given by keyword."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        """Legacy form: token ids passed as the third positional argument."""

    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Preferred form: ``PromptType`` input(s) passed positionally.

        Mirrors the implementation's signature so type checkers accept
        ``sampling_params`` given positionally and the
        ``prompt_adapter_request`` / ``guided_options_request`` /
        ``priority`` options, all of which the implementation supports.
        """

nunjunj's avatar
nunjunj committed
300
301
302
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy. Token IDs of the prompt(s), combined
                with ``prompts`` given as plain strings (or ``None``). Prefer
                passing :class:`~vllm.inputs.TokensPrompt` objects in
                ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A dict
                is converted into a ``GuidedDecodingRequest`` and may specify
                at most one kind of guided decoding.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the ``"generate"``
                task, or if ``guided_options_request`` is a dict with more
                than one entry.

        Note:
            Passing ``prompt_token_ids`` (with ``prompts`` as plain strings)
            is considered legacy and may be deprecated in the future. You
            should instead pass token ids as
            :class:`~vllm.inputs.TokensPrompt` objects via the ``prompts``
            parameter.
        """
        task = self.llm_engine.model_config.task
        if task != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "generate" in supported_tasks:
                messages.append(
                    "Your model supports the 'generate' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge string prompts and token ids into the
            # unified prompt format.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)
397

398
399
400
    def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        Returns:
            One ``BeamSearchOutput`` per prompt, in input order, holding up to
            ``params.beam_width`` sequences sorted best-first by
            ``get_beam_search_score``, each with its decoded ``text``.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Bind the tokenizer (and its EOS id) before defining the scoring
        # closure that uses it, rather than relying on late binding.
        tokenizer = self.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = [
            BeamSearchInstance(prompt if isinstance(prompt, list) else
                               tokenizer.encode(prompt)) for prompt in prompts
        ]

        for _ in range(max_tokens):
            # Flatten every instance's live beams into one batch.
            # chain.from_iterable is linear; sum(lists, []) was quadratic in
            # the number of instances.
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # instance k owns all_beams[pos[k]:pos[k + 1]]
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
506
507
    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat conversation is converted into a text prompt using the
        tokenizer and calls the :meth:`generate` method to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.
                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            add_generation_prompt: If True, adds a generation template
                to each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be `True`
                if `add_generation_prompt` is also `True`.
            tools: Tool definitions, if any, passed through to the chat
                template (as ``tools=``) for every conversation.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is List[...]
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]

        # Identical for every conversation: look them up once, not per loop
        # iteration.
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs, model_config, tokenizer)

            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            # The template may render to token ids or to text.
            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

621
622
623
624
625
626
627
628
    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: one text prompt, optionally with its token ids."""

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: several text prompts, optionally with token ids."""

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: one token-id prompt given by keyword."""

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: several token-id prompts given by keyword."""

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Legacy form: token ids passed as the third positional argument."""

    @overload
    def encode(
        self,
685
686
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
687
688
689
690
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
691
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
692
693
694
    ) -> List[EmbeddingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
695
696
697
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generate embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. A sequence is matched to
                the prompts one-to-one.
            prompt_token_ids: LEGACY. Token IDs for one prompt or a list of
                prompts. When given, ``prompts`` is interpreted as the legacy
                text prompt(s) and both are converted into prompt objects.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any. A list is matched to
                the prompts one-to-one.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized for the
                ``embedding`` task.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass prompt objects (e.g. :class:`~vllm.inputs.TextPrompt`
            or :class:`~vllm.inputs.TokensPrompt`) via the ``prompts``
            parameter.
        """
        # Fail fast, with a hint, when the model was loaded for another task.
        task = self.llm_engine.model_config.task
        if task != "embedding":
            messages = ["LLM.encode() is only supported for embedding models."]

            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "embedding" in supported_tasks:
                messages.append(
                    "Your model supports the 'embedding' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task embedding`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy calling convention: build prompt objects from the
            # separate text / token-id arguments.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)
772

773
774
775
776
777
778
    def start_profile(self) -> None:
        """Ask the underlying engine to begin profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to end profiling."""
        engine = self.llm_engine
        engine.stop_profile()

779
780
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Turn the legacy ``prompts`` / ``prompt_token_ids`` pair into a list
        of prompt dicts, one per request.

        Each argument may describe a single prompt or a batch. If both are
        supplied they must describe the same number of requests. In that case
        the text prompts are the ones turned into ``TextPrompt`` objects; the
        token ids only take part in the length check.

        Raises:
            ValueError: If neither argument is given, or if both are given
                with different lengths.
        """
        # skip_tokenizer_init is now checked in engine
        texts = None
        if prompts is not None:
            texts = [p["content"] for p in parse_and_batch_prompt(prompts)]

        token_lists = None
        if prompt_token_ids is not None:
            token_lists = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if texts is None and token_lists is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        if (texts is not None and token_lists is not None
                and len(texts) != len(token_lists)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        # Text prompts take precedence over token ids.
        if texts is not None:
            return [TextPrompt(prompt=text) for text in texts]
        if token_lists is None:  # unreachable: guarded above
            raise AssertionError
        return [TokensPrompt(prompt_token_ids=ids) for ids in token_lists]
822
823
824

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate per-request arguments and enqueue one request per prompt.

        Args:
            prompts: A single prompt (``str`` or prompt dict) or a sequence
                of prompts.
            params: One params object shared by every prompt, or a sequence
                with exactly one entry per prompt. ``SamplingParams`` entries
                are modified in place (guided options and ``output_kind``).
            lora_request: One LoRA request (or None) shared by every prompt,
                or a sequence with exactly one entry per prompt.
            prompt_adapter_request: Prompt adapter request applied to every
                prompt, if any.
            guided_options: Deprecated guided-decoding request, folded into
                each ``SamplingParams.guided_decoding``.
            priority: Optional per-prompt priorities; when non-empty it must
                have exactly one entry per prompt. Defaults to 0 otherwise.

        Raises:
            ValueError: If ``params``, ``lora_request`` or ``priority`` is a
                sequence whose length differs from the number of prompts, or
                if ``guided_options`` conflicts with an existing
                ``guided_decoding``.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use the same ``Sequence`` test for validation as for indexing below,
        # so that tuples are length-checked instead of being mis-indexed.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        # Iterate each element of a sequence (list *or* tuple); a bare params
        # object is treated as a one-element batch.
        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
871

872
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single prompt to the engine under a fresh request ID.

        The ID is the next value drawn from ``self.request_counter``,
        converted to ``str``. ``_run_engine`` converts these IDs back with
        ``int()`` when it orders the results.
        """
        # Draw the ID first so the counter advances exactly as before.
        new_id = str(next(self.request_counter))
        engine = self.llm_engine
        engine.add_request(
            new_id,
            prompt,
            params,
            priority=priority,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
889

890
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None
    ) -> SamplingParams:
        """Fold a legacy ``guided_options`` request into ``params``.

        ``params`` is modified in place and also returned.

        Args:
            params: The sampling parameters to update.
            guided_options: Legacy guided-decoding request. If None,
                ``params`` is returned unchanged.

        Returns:
            The same ``params`` object.

        Raises:
            ValueError: If ``params.guided_decoding`` is already set while
                ``guided_options`` is also given.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        # Map the legacy request fields one-to-one onto GuidedDecodingParams.
        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)
        return params

911
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until every queued request has finished.

        Args:
            use_tqdm: Whether to display a progress bar. For
                :class:`RequestOutput` results the bar's postfix also shows
                the running input/output token throughput.

        Returns:
            The finished outputs sorted by numeric request ID, i.e. in the
            order the requests were added.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue

                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # ``format_dict`` is recomputed on every access, so
                        # read it once. ``elapsed`` may still be 0.0 on a
                        # coarse clock; skip the rate update rather than
                        # divide by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even if the engine raised, so the terminal is not
            # left with a dangling progress bar.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))
955
956
957

    def _is_encoder_decoder_model(self):
        return self.llm_engine.is_encoder_decoder_model()