llm.py 40.6 KB
Newer Older
1
import itertools
2
import warnings
3
from contextlib import contextmanager
Joe Runde's avatar
Joe Runde committed
4
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
5
                    Union, cast, overload)
6

7
from tqdm import tqdm
8

9
from vllm import envs
10
11
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
12
from vllm.engine.arg_utils import EngineArgs, TaskOption
Joe Runde's avatar
Joe Runde committed
13
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
14
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
15
16
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
nunjunj's avatar
nunjunj committed
17
                                         parse_chat_messages)
18
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
19
from vllm.inputs.parse import parse_and_batch_prompt
20
from vllm.logger import init_logger
21
from vllm.lora.request import LoRARequest
22
23
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
24
25
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
26
from vllm.prompt_adapter.request import PromptAdapterRequest
27
28
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
29
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
30
31
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
32
from vllm.usage.usage_lib import UsageContext
33
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
34

35
36
# Module-level logger for this file, created with vLLM's `init_logger` helper.
logger = init_logger(__name__)

37
38

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allow API requests to read local images or
            videos from directories specified by the server file system.
            This is a security risk and should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If None (the default), it is treated as False.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`.
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_overrides: Arguments to be forwarded to the HuggingFace config.
        mm_processor_kwargs: Keyword-argument overrides for the multimodal
            processor, forwarded to :class:`~vllm.EngineArgs`.
        task: The task to initialize the model for, forwarded to
            :class:`~vllm.EngineArgs`. :meth:`generate` requires the model to
            be initialized for the ``"generate"`` task.
        pooling_type, pooling_norm, pooling_softmax, pooling_step_tag_id,
        pooling_returned_token_ids: Pooling options, forwarded unchanged to
            :class:`~vllm.EngineArgs` (presumably only relevant to pooling /
            embedding models).
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily mark the legacy generate/encode API as deprecated.

        Inside the ``with`` block, ``DEPRECATE_LEGACY`` is set to True, which
        makes legacy keyword usage (e.g. ``prompt_token_ids`` in
        :meth:`generate`) trigger a deprecation notice. On exit the previous
        value is restored, even if the block raises, so nested uses and
        exceptions cannot leave the flag stuck in the wrong state.

        NOTE(review): the deprecation checks read ``LLM.DEPRECATE_LEGACY``
        directly, so invoking this on a subclass only sets the subclass
        attribute and has no effect on those checks.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = previous
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[dict] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        pooling_type: Optional[str] = None,
        pooling_norm: Optional[bool] = None,
        pooling_softmax: Optional[bool] = None,
        pooling_step_tag_id: Optional[int] = None,
        pooling_returned_token_ids: Optional[List[int]] = None,
        **kwargs,
    ) -> None:
        """LLM constructor.

        Every explicit parameter, plus any extra ``**kwargs``, is forwarded
        to :class:`~vllm.EngineArgs`, from which the engine is built. See the
        class docstring for the meaning of each parameter.

        ``disable_log_stats`` defaults to True here unless the caller passes
        it explicitly in ``**kwargs``.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.
        """
        # Stats logging is off by default for offline use; an explicit
        # caller-supplied value wins.
        kwargs.setdefault("disable_log_stats", True)

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            pooling_type=pooling_type,
            pooling_norm=pooling_norm,
            pooling_softmax=pooling_softmax,
            pooling_step_tag_id=pooling_step_tag_id,
            pooling_returned_token_ids=pooling_returned_token_ids,
            **kwargs,
        )

        # Logic to switch between engines is done at runtime instead of import
        # to avoid import order issues
        self.engine_class = self.get_engine_class()
        self.llm_engine = self.engine_class.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

        # Per-instance counter; presumably consumed by request-submission code
        # outside this view to number requests — TODO confirm.
        self.request_counter = Counter()
    @staticmethod
    def get_engine_class() -> Type[LLMEngine]:
        """Return the engine class to instantiate.

        Returns the V1 ``LLMEngine`` when ``envs.VLLM_USE_V1`` is truthy and
        the default ``LLMEngine`` otherwise. The V1 class is imported lazily,
        only when it is selected.
        """
        if not envs.VLLM_USE_V1:
            return LLMEngine
        # Lazy import: the v1 package isn't distributed
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        return V1LLMEngine  # type: ignore
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer currently held by the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        Tokenizers that do not already look cached are wrapped with
        :func:`get_cached_tokenizer` before being installed.
        """
        tokenizer_group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # CachedTokenizer classes are created dynamically, so there is no
        # class to isinstance-check against; we fall back to comparing the
        # class name. A user-defined tokenizer whose class name happens to
        # start with "Cached" will be misjudged as already cached.
        already_cached = tokenizer.__class__.__name__.startswith("Cached")
        tokenizer_group.tokenizer = (tokenizer if already_cached else
                                     get_cached_tokenizer(tokenizer))
    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # Current API: prompt(s) positional, everything else keyword
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: (Legacy) Token IDs of the prompt(s). Reported as
                deprecated when ``LLM.DEPRECATE_LEGACY`` is enabled.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. When
                given as a dict it may specify at most one guided decoding
                option; it is then converted to a ``GuidedDecodingRequest``.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the ``"generate"``
                task, or if ``guided_options_request`` is a dict specifying
                more than one guided decoding option.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and may be
            deprecated in the future. You should instead pass token IDs via
            the ``prompts`` parameter (e.g. as
            :class:`~vllm.inputs.TokensPrompt`).
        """
        task = self.llm_engine.model_config.task
        if task != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            # Point the user at the fix when the model could generate but was
            # initialized for a different task.
            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "generate" in supported_tasks:
                messages.append(
                    "Your model supports the 'generate' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy calling convention: merge text prompts and token ids.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
    def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        Returns:
            One ``BeamSearchOutput`` per prompt, in input order. Each holds at
            most ``params.beam_width`` sequences, sorted best-first by
            ``get_beam_search_score``, with ``text`` set to the decoded tokens.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            prompt_tokens = prompt if isinstance(
                prompt, list) else tokenizer.encode(prompt)
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten every instance's live beams into one batch.
            # chain.from_iterable is linear; sum(lists, []) was quadratic.
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # all_beams[pos[k]:pos[k + 1]] are the beams of instances[k].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            # An EOS token finishes the beam unless EOS is
                            # being ignored.
                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Beams still alive after max_tokens compete with completed ones.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs
    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered into a prompt with the tokenizer's chat
        template (the rendered prompt may be text or, for Mistral tokenizers,
        token IDs), and the resulting batch is passed to :meth:`generate`.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

                - Each conversation is represented as a list of messages.
                - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
                If not provided, the model's default chat template will be used.
            add_generation_prompt: If True, adds a generation template
                to each message.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be `True`
                if `add_generation_prompt` is also `True`.
            tools: Tool definitions, passed through to the chat template
                renderer.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is List[...]
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]

        # The tokenizer and model config are the same for every conversation,
        # so fetch them once instead of once per iteration.
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs, model_config, tokenizer)

            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )
    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...
    @overload  # LEGACY: multi (prompt + optional token ids)
647
648
    def encode(
        self,
649
        prompts: List[str],
650
        pooling_params: Optional[Union[PoolingParams,
651
                                       Sequence[PoolingParams]]] = None,
652
653
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
654
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
655
656
657
658
659
660
661
662
663
664
665
666
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
667
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
668
669
670
671
672
673
674
675
676
677
678
679
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
680
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
681
682
683
684
685
686
687
688
689
690
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
691
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
692
693
694
695
696
697
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload
    def encode(
        self,
698
699
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
700
701
702
703
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
704
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
705
706
707
    ) -> List[EmbeddingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
708
709
710
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Compute embeddings for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. Either one
                object shared by all prompts or a sequence with one entry per
                prompt. If None, we use the default pooling parameters.
            prompt_token_ids: LEGACY. Token IDs for one prompt or a batch of
                prompts. When given, ``prompts`` (if any) is interpreted as
                legacy text prompt(s).
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use, if any. A list supplies one
                request per prompt.
            prompt_adapter_request: Prompt Adapter request to use, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            computed embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the
                ``"embedding"`` task.

        Note:
            Passing token IDs through ``prompt_token_ids`` (optionally together
            with ``prompts`` as plain strings) is the legacy calling convention
            and may be deprecated in the future. You should instead pass prompt
            objects (e.g. :class:`~vllm.inputs.TokensPrompt` for token IDs)
            positionally via the ``prompts`` parameter.
        """
        task = self.llm_engine.model_config.task
        if task != "embedding":
            messages = ["LLM.encode() is only supported for embedding models."]

            # Give a more actionable hint when the model could serve
            # embeddings but was started for a different task.
            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "embedding" in supported_tasks:
                messages.append(
                    "Your model supports the 'embedding' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task embedding`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: combine the (text, token IDs) pair into prompt
            # dicts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  EmbeddingRequestOutput)
786

787
788
789
790
791
792
    def start_profile(self) -> None:
        """Forward a start-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Forward a stop-profiling request to the underlying engine."""
        engine = self.llm_engine
        engine.stop_profile()

793
794
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Turn the legacy ``prompts`` / ``prompt_token_ids`` pair into a list
        of prompt dicts.

        Each argument is first normalized into a batch with
        ``parse_and_batch_prompt``. If text prompts are present, one
        ``TextPrompt`` is built per text; in that case ``prompt_token_ids`` is
        only used to check that the batch sizes agree. Otherwise one
        ``TokensPrompt`` is built per token-ID list.

        Raises:
            ValueError: If both arguments are given with different batch
                sizes, or if neither is given.
        """
        # skip_tokenizer_init is now checked in engine

        texts: Optional[List[str]] = None
        if prompts is not None:
            texts = [p["content"] for p in parse_and_batch_prompt(prompts)]

        token_batches: Optional[List[List[int]]] = None
        if prompt_token_ids is not None:
            token_batches = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        both_given = texts is not None and token_batches is not None
        if both_given and len(texts) != len(token_batches):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        parsed_prompts: List[PromptType]
        if texts is not None:
            # Text wins when both are supplied.
            parsed_prompts = [TextPrompt(prompt=text) for text in texts]
            return parsed_prompts

        if token_batches is not None:
            parsed_prompts = [
                TokensPrompt(prompt_token_ids=ids) for ids in token_batches
            ]
            return parsed_prompts

        raise ValueError("Either prompts or prompt_token_ids must be "
                         "provided.")
836
837
838

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate per-prompt arguments, then enqueue one request per prompt.

        Args:
            prompts: A single prompt (``str`` or prompt ``dict``) or a
                sequence of prompts.
            params: One sampling/pooling params object shared by every prompt,
                or a sequence (list or tuple) with one entry per prompt.
            lora_request: ``None``, one LoRA request shared by every prompt,
                or a sequence with one entry per prompt.
            prompt_adapter_request: Prompt adapter request applied to every
                prompt.
            guided_options: Deprecated guided decoding options, merged into
                each ``SamplingParams`` via ``_add_guided_params``.
            priority: Optional per-prompt priorities. When omitted (or
                empty), every request gets priority 0.

        Raises:
            ValueError: If ``params``, ``lora_request`` or ``priority`` is a
                sequence whose length differs from the number of prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Validate every per-prompt argument up front, using the same
        # ``Sequence`` test as the indexing in the loop below. Checking only
        # ``list`` let a mismatched tuple through, which then either raised
        # IndexError after some requests were already enqueued or silently
        # dropped the extra entries.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        all_params = params if isinstance(params, Sequence) else (params, )
        for sp in all_params:
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
885

886
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit one prompt to the engine under a freshly drawn request ID.

        The ID is the next value of ``self.request_counter`` converted to
        ``str``. All remaining arguments are forwarded unchanged to
        ``self.llm_engine.add_request``.
        """
        next_id = next(self.request_counter)
        engine = self.llm_engine
        engine.add_request(
            str(next_id),
            prompt,
            params,
            priority=priority,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
903

904
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Merge deprecated ``guided_options`` into ``params`` in place.

        Args:
            params: Sampling parameters to update.
            guided_options: Legacy guided decoding request. If ``None``,
                ``params`` is returned untouched.

        Returns:
            The same ``params`` object, with ``guided_decoding`` populated
            from ``guided_options`` when it was given.

        Raises:
            ValueError: If both ``guided_options`` and
                ``params.guided_decoding`` are set, since the two sources
                would conflict.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            # Note the trailing space: the two literals are concatenated.
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)

        return params

925
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to display a progress bar. For
                ``RequestOutput`` results, the bar also shows estimated input
                and output token throughput.

        Returns:
            The finished outputs, sorted by integer request ID. IDs come from
            ``self.request_counter`` in ``_add_request``.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # The elapsed time can still read 0.0 (e.g. with a
                        # coarse timer), so skip the speed estimate rather
                        # than divide by zero.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even if an engine step raises, so the terminal is
            # not left with a dangling progress line.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were submitted before them.
        return sorted(outputs, key=lambda x: int(x.request_id))
969
970
971

    def _is_encoder_decoder_model(self):
        return self.llm_engine.is_encoder_decoder_model()