"vllm/vscode:/vscode.git/clone" did not exist on "5b5f350d67a1e1efb7dbe8b18fe2353ad94857a1"
llm.py 47.5 KB
Newer Older
1
import itertools
2
import json
3
import warnings
4
from contextlib import contextmanager
Joe Runde's avatar
Joe Runde committed
5
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
6
                    Union, cast, overload)
7

8
from tqdm import tqdm
9

10
from vllm import envs
11
12
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
13
from vllm.config import CompilationConfig
14
15
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
16
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
17
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
18
                                         ChatTemplateContentFormatOption,
19
20
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
21
22
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
23
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
24
from vllm.inputs.parse import parse_and_batch_prompt
25
from vllm.logger import init_logger
26
from vllm.lora.request import LoRARequest
27
28
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
29
from vllm.outputs import PoolingRequestOutput, RequestOutput
30
from vllm.pooling_params import PoolingParams
31
from vllm.prompt_adapter.request import PromptAdapterRequest
32
33
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
34
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
35
36
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
37
from vllm.usage.usage_lib import UsageContext
38
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
39

40
41
# Module-level logger, named after this module via init_logger(__name__).
logger = init_logger(__name__)

42
43

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
44
45
46
47
48
49
50
51
52
53
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
54
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
55
56
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
57
58
59
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
60
61
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
62
63
64
65
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
66
67
68
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
69
70
71
72
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
73
        quantization: The method used to quantize the model weights. Currently,
74
            we support "awq", "gptq", and "fp8" (experimental).
75
76
77
78
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
79
80
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
81
82
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
83
84
85
86
87
88
89
90
91
92
93
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
94
95
96
97
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
98
99
100
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
101
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
102
            When a sequence has context length larger than this, we fall back
103
104
105
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
106
107
108
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
109
110
111
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
112
113
114
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
115
116
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)
nunjunj's avatar
nunjunj committed
117

118
119
120
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
121
    """
122

123
124
125
    #: Toggles whether the legacy generate/encode API is reported as
    #: deprecated. Off by default; see ``deprecate_legacy_api``.
    DEPRECATE_LEGACY: ClassVar[bool] = False

    #: Toggles whether positional arguments (other than ``model``) passed to
    #: :meth:`LLM.__init__` are reported as deprecated.
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True

132
133
134
135
136
137
138
139
140
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily turn on deprecation of the legacy API.

        While the ``with`` block runs, :attr:`DEPRECATE_LEGACY` is ``True``.
        This makes the ``deprecate_kwargs`` decorator on :meth:`generate`
        treat the ``prompt_token_ids`` keyword as deprecated.

        On exit the flag is restored to whatever value it had before entry.
        This also happens when the body raises, so an exception can no longer
        leave the flag stuck at ``True``, and nested uses do not clobber each
        other.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = previous

141
142
143
144
145
146
147
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, Dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        """Build the engine backing this :class:`LLM` instance.

        Every named argument, plus any extra ``**kwargs``, is forwarded to
        :class:`~vllm.EngineArgs`. The class docstring describes each one.

        ``compilation_config`` may be given in three forms:

        - an ``int``, taken as the optimization level;
        - a ``dict``, taken as the full configuration;
        - an existing :class:`~vllm.config.CompilationConfig`, used as-is.

        Note:
            If ``enforce_eager`` is left unset (``None``), it defaults to
            ``False``.
        """
        # Stats logging is off by default. An explicit `disable_log_stats`
        # from the caller takes precedence.
        kwargs.setdefault("disable_log_stats", True)

        compilation_config_instance: Optional[CompilationConfig]
        if isinstance(compilation_config, int):
            # An integer is shorthand for the optimization level; from_cli
            # parses it from its string form.
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        elif isinstance(compilation_config, dict):
            # A dict is the full configuration; serialize it to the JSON form
            # that from_cli parses.
            compilation_config_instance = CompilationConfig.from_cli(
                json.dumps(compilation_config))
        else:
            # Either None or an already-built CompilationConfig.
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Logic to switch between engines is done at runtime instead of import
        # to avoid import order issues
        self.engine_class = self.get_engine_class()

        # TODO(rob): enable mp by default (issue with fork vs spawn)
        self.llm_engine = self.engine_class.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

        # Per-instance counter; presumably used to derive request ids by the
        # request-submission helpers (not visible here) — TODO confirm.
        self.request_counter = Counter()

Joe Runde's avatar
Joe Runde committed
236
237
238
239
240
241
242
243
    @staticmethod
    def get_engine_class() -> Type[LLMEngine]:
        """Pick the engine class according to ``envs.VLLM_USE_V1``.

        Without the flag, the regular :class:`LLMEngine` is returned.
        Otherwise the V1 engine class is imported on demand and returned.
        """
        if not envs.VLLM_USE_V1:
            return LLMEngine

        # Imported lazily: the v1 package isn't distributed.
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        return V1LLMEngine  # type: ignore

244
245
246
247
248
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's :class:`TokenizerGroup`."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` as the engine's tokenizer.

        A tokenizer that is not already a cached wrapper is wrapped with
        :func:`get_cached_tokenizer` before it is stored.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # Cached tokenizer classes are created dynamically, so the class name
        # is the only thing left to compare. A user-defined tokenizer whose
        # class name happens to start with 'Cached' will be misjudged.
        is_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer
                           if is_cached else get_cached_tokenizer(tokenizer))
257

258
259
260
261
262
263
264
265
    # NOTE: These overloads exist only for static type checkers; the single
    # undecorated-by-@overload `generate` implementation handles every call.
    # Each overload also lists `prompt_adapter_request`,
    # `guided_options_request` and `priority` so that callers passing those
    # implementation parameters type-check.

    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        # Positional-or-keyword, matching the implementation, so that
        # `llm.generate(TokensPrompt(...), params)` type-checks.
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
332
333
334
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        The given prompts are automatically batched, considering the memory
        constraint. For the best performance, put all of your prompts into a
        single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing token IDs for the
                prompt(s). Prefer passing them through ``prompts``, e.g. as a
                :class:`~vllm.inputs.TokensPrompt`.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. A plain
                dict is converted to a ``GuidedDecodingRequest`` and may
                specify only one guided decoding option.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model was not initialized for the ``generate``
                task, or if ``guided_options_request`` is a dict with more
                than one entry.

        Note:
            Passing ``prompts`` as a keyword argument or using
            ``prompt_token_ids`` is considered legacy and may be deprecated in
            the future. You should instead pass the inputs positionally via
            the ``prompts`` parameter.
        """
        model_config = self.llm_engine.model_config
        task = model_config.task
        if task != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            # Point the user at the fix when the model could generate but was
            # initialized for a different task.
            if "generate" in model_config.supported_tasks:
                messages.append(
                    "Your model supports the 'generate' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge the separate token-id argument into prompts.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
429

430
431
432
    def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per prompt, in input order. Each
            holds at most ``params.beam_width`` sequences, sorted best-first
            by :func:`get_beam_search_score`. Each sequence's ``text`` is the
            decoding of all of its tokens, prompt tokens included.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Fetch the tokenizer (and its EOS id) before defining the sort key
        # that uses it, instead of relying on late closure binding.
        tokenizer = self.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            prompt_tokens = prompt if isinstance(
                prompt, list) else tokenizer.encode(prompt)
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten the live beams of all instances into one batch.
            # chain.from_iterable is linear, unlike sum(lists, []).
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # instance_start_and_end[k] is the [start, end) range of all_beams
            # (and of the generate() outputs) that belongs to instances[k].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for current_beam, result in zip(all_beams[start:end],
                                                output[start:end]):
                    if result.outputs[0].logprobs is None:
                        # the sequence is completed because of the
                        # max-model-len or abortion. we don't need to add it
                        # to the new beams.
                        continue

                    logprobs = result.outputs[0].logprobs[0]
                    for token_id, logprob_obj in logprobs.items():
                        new_beam = BeamSearchSequence(
                            tokens=current_beam.tokens + [token_id],
                            logprobs=current_beam.logprobs + [logprobs],
                            cum_logprob=current_beam.cum_logprob +
                            logprob_obj.logprob)

                        if token_id == eos_token_id and not ignore_eos:
                            instance.completed.append(new_beam)
                        else:
                            instance_new_beams.append(new_beam)

                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Unfinished beams compete with finished ones for the final slots.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
538
539
    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        The chat template turns each conversation into a prompt, and the
        prompts are then passed to :meth:`generate` to produce the responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request(s) to use for generation, if any.
                Forwarded unchanged to :meth:`generate`.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "auto" (default) lets
                :func:`resolve_chat_template_content_format` choose the format
                from the chat template and tokenizer.
              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

            add_generation_prompt: If True, adds a generation template
                to each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Tool definitions passed through to the chat template, if
                any.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is List[...]
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            # Mistral tokenizers render from the raw messages; other
            # tokenizers render from the parsed conversation.
            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            # The template may yield token ids or text; wrap accordingly.
            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

673
674
675
676
677
678
679
680
    # Typing overloads for ``encode``. The first five describe the LEGACY
    # calling convention, where raw text and/or token IDs are passed through
    # ``prompts`` / ``prompt_token_ids``. The last one is the current
    # convention: ``prompts`` is positional-only and everything else is
    # keyword-only.

    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        """Run the input prompts through the pooling (embedding) model.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters.
            prompt_token_ids: LEGACY. Token IDs of the prompts; when given,
                ``prompts`` is interpreted as legacy raw text.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the
                ``embedding`` task.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass prompts (including ``TokensPrompt`` dicts for token
            IDs) positionally via the ``prompts`` parameter.
        """
        task = self.llm_engine.model_config.task
        if task != "embedding":
            messages = ["LLM.encode() is only supported for embedding models."]

            # Point the user at the right flag if the model could do it.
            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "embedding" in supported_tasks:
                messages.append(
                    "Your model supports the 'embedding' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task embedding`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # LEGACY path: build prompts from the v1 kwargs.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)

    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        """Generates similarity scores for all pairs <text,text_pair>.

        The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
        the text_1 sentence will be replicated N times to pair with the text_2
        sentences. The input pairs are used to build a list of prompts for the
        cross encoder model. This method automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the text_2 list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: If set, each (text_1, text_2) pair is
                tokenized with truncation enabled and ``max_length`` set to
                this value.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not an embedding cross-encoder, the
                tokenizer is a ``MistralTokenizer``, a prompt is multi-modal
                or cannot be converted to text, or the input lengths are
                incompatible.
        """
        task = self.llm_engine.model_config.task
        if task != "embedding":
            messages = ["LLM.score() is only supported for embedding models."]

            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "embedding" in supported_tasks:
                messages.append(
                    "Your model supports the 'embedding' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task embedding`.")

            raise ValueError(" ".join(messages))

        if not self.llm_engine.model_config.is_cross_encoder:
            raise ValueError("Your model does not support the cross encoding")

        tokenizer = self.llm_engine.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        def ensure_str(prompt: SingletonPrompt) -> str:
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for cross encoding")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check instead of `assert`, which is stripped under -O
            # and would let a non-str reach the tokenizer.
            if not isinstance(prompt, str):
                raise ValueError(
                    f"Unsupported prompt for cross encoding: {prompt!r}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        text_1 = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        text_2 = [ensure_str(t) for t in text_2]

        if len(text_1) > 1 and len(text_1) != len(text_2):
            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
        if len(text_1) == 0:
            raise ValueError("At least one text element must be given")
        if len(text_2) == 0:
            raise ValueError("At least one text_pair element must be given")

        if len(text_1) == 1:
            # 1 -> N: broadcast the single query over every text_2 entry.
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))
        pooling_params = PoolingParams()

        tokenization_kwargs: Dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)

    def start_profile(self) -> None:
        """Forward a start-profiling request to ``self.llm_engine``."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Forward a stop-profiling request to ``self.llm_engine``."""
        self.llm_engine.stop_profile()

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ) -> List[PromptType]:
        """Convert the legacy ``prompts`` / ``prompt_token_ids`` kwargs into
        a list of prompts, one per request.

        Args:
            prompts: A single text prompt or a list of them, or None.
            prompt_token_ids: Token IDs for a single prompt or a list of
                them, or None.

        Returns:
            One ``TextPrompt`` or ``TokensPrompt`` per request.

        Raises:
            ValueError: If both arguments are given with different lengths,
                or if neither is given.
        """
        # skip_tokenizer_init is now checked in engine

        # Normalize single and batched inputs into plain lists.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        num_requests = None
        if prompts is not None:
            num_requests = len(prompts)

        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

            num_requests = len(prompt_token_ids)

        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        parsed_prompts: List[PromptType] = []
        for i in range(num_requests):
            item: PromptType

            # NOTE(review): when both are given, the text prompt takes
            # precedence and prompt_token_ids[i] is silently ignored.
            if prompts is not None:
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
            else:
                raise AssertionError

            parsed_prompts.append(item)

        return parsed_prompts

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate per-request arguments and add one engine request per
        prompt.

        ``params`` and ``lora_request`` may each be a single value shared by
        every prompt, or a sequence (list or tuple) with one entry per
        prompt. ``priority``, if non-empty, must have one entry per prompt.

        Raises:
            ValueError: If a per-request sequence's length differs from the
                number of prompts, or (via ``_add_guided_params``) if guided
                decoding is configured twice.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use the same `Sequence` test as the indexing below, so that tuples
        # are validated (and configured) exactly like lists.
        params_per_request = isinstance(params, Sequence)
        lora_per_request = isinstance(lora_request, Sequence)

        if params_per_request and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if lora_per_request and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if params_per_request else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if params_per_request else params,
                lora_request=(lora_request[i]
                              if lora_per_request else lora_request),
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )

    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Add a single request to ``self.llm_engine``.

        The request ID is the decimal string of the next value from
        ``self.request_counter``. ``_run_engine`` relies on this when it
        sorts outputs with ``int(request_id)``.
        """
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None
    ) -> SamplingParams:
        """Copy legacy ``guided_options`` onto ``params.guided_decoding``.

        ``params`` is modified in place and also returned. If
        ``guided_options`` is None, ``params`` is returned unchanged.

        Raises:
            ValueError: If both ``guided_options`` and
                ``params.guided_decoding`` are set.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Step ``self.llm_engine`` until it has no unfinished requests.

        Args:
            use_tqdm: Whether to show a progress bar. For ``RequestOutput``
                results, the bar also shows estimated input/output token
                throughput.

        Returns:
            All finished outputs, sorted by numeric request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if not output.finished:
                    continue
                outputs.append(output)
                if not use_tqdm:
                    continue
                if isinstance(output, RequestOutput):
                    # Calculate tokens only for RequestOutput
                    assert output.prompt_token_ids is not None
                    total_in_toks += len(output.prompt_token_ids)
                    total_out_toks += sum(
                        len(stp.token_ids) for stp in output.outputs)
                    elapsed = pbar.format_dict["elapsed"]
                    # Elapsed time can be 0 on coarse clocks; skip the speed
                    # update then instead of raising ZeroDivisionError.
                    if elapsed > 0:
                        in_spd = total_in_toks / elapsed
                        out_spd = total_out_toks / elapsed
                        pbar.postfix = (
                            f"est. speed input: {in_spd:.2f} toks/s, "
                            f"output: {out_spd:.2f} toks/s")
                pbar.update(1)

        if use_tqdm:
            pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))