llm.py 47.2 KB
Newer Older
1
import itertools
2
import json
3
import warnings
4
from contextlib import contextmanager
Joe Runde's avatar
Joe Runde committed
5
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
6
                    Union, cast, overload)
7

8
from tqdm import tqdm
9

10
from vllm import envs
11
12
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
13
from vllm.config import CompilationConfig
14
15
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
16
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
17
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
18
                                         ChatTemplateContentFormatOption,
19
20
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
21
22
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
23
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
24
from vllm.inputs.parse import parse_and_batch_prompt
25
from vllm.logger import init_logger
26
from vllm.lora.request import LoRARequest
27
28
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
29
30
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
31
from vllm.prompt_adapter.request import PromptAdapterRequest
32
33
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
34
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
35
36
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
37
from vllm.usage.usage_lib import UsageContext
38
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
39

40
41
# Module-level logger for this file, created through vLLM's ``init_logger``.
logger = init_logger(__name__)

42
43

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        allowed_local_media_path: Allows API requests to read local images
            or videos from the directories specified by the server file
            system. This is a security risk and should only be enabled in
            trusted environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
            If None (the default), it is treated as False.
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`.
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
        mm_processor_kwargs: Multi-modal processor keyword-argument overrides;
            forwarded to :class:`~vllm.EngineArgs`.
        task: The task to initialize the model for; forwarded to
            :class:`~vllm.EngineArgs`. :meth:`generate` requires the model to
            be initialized for the ``"generate"`` task.
        override_pooler_config: Overrides for the pooler configuration;
            forwarded to :class:`~vllm.EngineArgs`.
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        cls.DEPRECATE_LEGACY = True

        yield

        cls.DEPRECATE_LEGACY = False

141
142
143
144
145
146
147
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, Dict[str, Any]]] = None,
        **kwargs,
    ) -> None:
        """Assemble :class:`EngineArgs` and create the underlying engine.

        Every named argument is documented on the class. Anything passed
        through ``**kwargs`` is forwarded unchanged to :class:`EngineArgs`.

        Note: if ``enforce_eager`` is unset (``None``), it defaults to
        ``False``.
        """
        # Stats logging is off unless the caller passes
        # ``disable_log_stats`` explicitly.
        kwargs.setdefault("disable_log_stats", True)

        # Both accepted forms (an int level or a dict) are serialized to JSON
        # and parsed by CompilationConfig.from_cli.
        parsed_compilation_config = (
            None if compilation_config is None else
            CompilationConfig.from_cli(json.dumps(compilation_config)))

        engine_config_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=parsed_compilation_config,
            **kwargs,
        )

        # The engine class is chosen here, at runtime, rather than at import
        # time, to avoid import-order issues.
        self.engine_class = self.get_engine_class()

        # TODO(rob): enable mp by default (issue with fork vs spawn)
        self.llm_engine = self.engine_class.from_engine_args(
            engine_config_args, usage_context=UsageContext.LLM_CLASS)

        self.request_counter = Counter()

    @staticmethod
    def get_engine_class() -> Type[LLMEngine]:
        """Return the engine class to instantiate.

        When ``envs.VLLM_USE_V1`` is falsy, the classic :class:`LLMEngine` is
        used. Otherwise the V1 engine is imported lazily and returned.
        """
        if not envs.VLLM_USE_V1:
            return LLMEngine
        # Lazy import: the v1 package isn't distributed
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        return V1LLMEngine  # type: ignore

    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer stored on the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer is wrapped with :func:`get_cached_tokenizer` before being
        stored unless its class name already marks it as cached.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # Cached tokenizer classes are created dynamically, so the class name
        # is the only thing available to check. A user-defined tokenizer whose
        # class name starts with 'Cached' will be misjudged and stored as-is.
        already_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if already_cached else
                           get_cached_tokenizer(tokenizer))

    # Static-typing overloads for ``generate``. The first five describe the
    # LEGACY call styles (``prompts`` as str / list of str and/or
    # ``prompt_token_ids``); the last one describes the current API.

    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
    ) -> List[RequestOutput]:
        ...

    # Current API. The extra keyword-only parameters mirror those accepted
    # by the implementation; without them, type checkers reject valid calls
    # such as ``generate(prompts, guided_options_request=...)``.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: LEGACY. Token IDs of the prompt(s). When given,
                they are converted together with ``prompts`` by
                ``_convert_v1_inputs``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options for the requests,
                if any. A dict may specify at most one guided decoding option
                and is converted to a :class:`GuidedDecodingRequest`.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the ``"generate"``
                task, or if ``guided_options_request`` is a dict specifying
                more than one guided decoding option.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and may be
            deprecated in the future. Pass every prompt through the
            ``prompts`` parameter instead. To supply token IDs directly, wrap
            them in :class:`~vllm.inputs.TokensPrompt`.
        """
        task = self.llm_engine.model_config.task
        if task != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "generate" in supported_tasks:
                messages.append(
                    "Your model supports the 'generate' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # LEGACY path: merge `prompts` and `prompt_token_ids` into the
            # current prompt format.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            # At most one guided decoding option may be specified.
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        # Run the engine over the queued requests (see `_run_engine`).
        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)

    def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        Returns:
            One ``BeamSearchOutput`` per prompt, in input order. Each holds up
            to ``params.beam_width`` highest-scoring sequences (finished and
            unfinished), with their ``text`` decoded by the tokenizer.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """

        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        tokenizer = self.get_tokenizer()

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         tokenizer.eos_token_id,
                                         length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            prompt_tokens = prompt if isinstance(
                prompt, list) else tokenizer.encode(prompt)
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten the live beams of every instance into one batch.
            # chain.from_iterable is linear in the number of beams, whereas
            # sum(lists, []) re-copies the growing accumulator on each add.
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))

            if not all_beams:
                break

            # Instance k owns the slice all_beams[pos[k]:pos[k + 1]].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == tokenizer.eos_token_id and \
                                not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Unfinished beams compete with finished ones for the final slots.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for one chat conversation or a batch of them.

        Each conversation is rendered into a prompt with the chat template,
        and all resulting prompts are passed to :meth:`generate` in a
        single call.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: Either a single conversation or a list of conversations.

              - A conversation is a list of messages.
              - A message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "auto" (the default) lets
                ``resolve_chat_template_content_format`` choose, based on the
                chat template and the tokenizer.
              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

            add_generation_prompt: If True, adds a generation template
                to each conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Tool definitions passed to the chat-template helper when
                each conversation is rendered, if any.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        # A list of lists is a batch; a flat list is a single conversation.
        conversations: List[List[ChatCompletionMessageParam]] = (
            cast(List[List[ChatCompletionMessageParam]], messages)
            if is_list_of(messages, list) else
            [cast(List[ChatCompletionMessageParam], messages)])

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )

        engine_prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for conversation_msgs in conversations:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            parsed_conversation, mm_data = parse_chat_messages(
                conversation_msgs,
                model_config,
                tokenizer,
                content_format=content_format,
            )

            # Mistral tokenizers render from the raw messages; every other
            # tokenizer renders from the parsed conversation.
            rendered: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                rendered = apply_mistral_chat_template(
                    tokenizer,
                    messages=conversation_msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                rendered = apply_hf_chat_template(
                    tokenizer,
                    conversation=parsed_conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            # A rendered template is either token IDs or text.
            engine_prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(rendered, int):
                engine_prompt = TokensPrompt(prompt_token_ids=rendered)
            else:
                engine_prompt = TextPrompt(prompt=rendered)

            if mm_data is not None:
                engine_prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                engine_prompt["mm_processor_kwargs"] = mm_processor_kwargs

            engine_prompts.append(engine_prompt)

        return self.generate(
            engine_prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    # Typing-only overloads for ``encode``; the runtime implementation is the
    # final, undecorated-by-@overload definition. The ``LEGACY`` forms accept
    # the ``prompt_token_ids`` argument. Every form also accepts
    # ``prompt_adapter_request``, matching the implementation's signature.
    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    # Current form: ``prompts`` is positional-only and every option is
    # keyword-only.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
741
742
743
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the embeddings for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. Either a single value
                shared by all prompts or one value per prompt.
            prompt_token_ids: DEPRECATED. Token IDs for the prompts (single
                or batched). When given, ``prompts`` and ``prompt_token_ids``
                are converted together through the legacy input path.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the
                ``"embedding"`` task.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass prompts (e.g. a ``TokensPrompt`` for token IDs)
            positionally via the ``prompts`` parameter.
        """
        task = self.llm_engine.model_config.task
        if task != "embedding":
            messages = ["LLM.encode() is only supported for embedding models."]

            # Point the user at the right flag when the model could serve
            # embeddings but was started for a different task.
            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "embedding" in supported_tasks:
                messages.append(
                    "Your model supports the 'embedding' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task embedding`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call style: merge text prompts and token IDs.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  EmbeddingRequestOutput)
819

820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates similarity scores for all pairs <text,text_pair>.

        The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case
        the text_1 sentence will be replicated N times to pair with the text_2
        sentences. The input pairs are used to build a list of prompts for the
        cross encoder model. This method automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the text_2 list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: If set, each pair is tokenized with
                truncation enabled and ``max_length`` equal to this value.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the
                ``"embedding"`` task or is not a cross encoder, if the
                tokenizer is a ``MistralTokenizer``, if a prompt is
                multi-modal or cannot be turned into a string, or if the
                input lengths are not 1:1, 1:N or N:N.
        """
        task = self.llm_engine.model_config.task
        if task != "embedding":
            messages = ["LLM.score() is only supported for embedding models."]

            supported_tasks = self.llm_engine.model_config.supported_tasks
            if "embedding" in supported_tasks:
                messages.append(
                    "Your model supports the 'embedding' task, but is "
                    f"currently initialized for the '{task}' task. Please "
                    "initialize the model using `--task embedding`.")

            raise ValueError(" ".join(messages))

        if not self.llm_engine.model_config.is_cross_encoder:
            raise ValueError("Your model does not support the cross encoding")

        tokenizer = self.llm_engine.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        def ensure_str(prompt: SingletonPrompt) -> str:
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for cross encoding")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check instead of `assert`, which `python -O` strips and
            # would let malformed prompts reach the tokenizer.
            if not isinstance(prompt, str):
                raise ValueError("Unsupported prompt for cross encoding: "
                                 f"got {type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        text_1 = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        text_2 = [ensure_str(t) for t in text_2]

        if len(text_1) > 1 and len(text_1) != len(text_2):
            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
        if len(text_1) == 0:
            raise ValueError("At least one text element must be given")
        if len(text_2) == 0:
            raise ValueError("At least one text_pair element must be given")

        if len(text_1) == 1:
            # 1:N case: pair the single text_1 entry with every text_2 entry.
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))
        pooling_params = PoolingParams()

        tokenization_kwargs: Dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  EmbeddingRequestOutput)

942
943
944
945
946
947
    def start_profile(self) -> None:
        """Start profiling by delegating to ``self.llm_engine``."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Stop profiling by delegating to ``self.llm_engine``."""
        engine = self.llm_engine
        engine.stop_profile()

948
949
    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        """Convert legacy ``prompts`` / ``prompt_token_ids`` arguments into a
        list of prompt dicts.

        Each argument may be a single item or a batch; both are normalized
        with ``parse_and_batch_prompt``. If both are given, their batch sizes
        must match and the text prompts take precedence: the result is then
        built from ``prompts`` only.

        Args:
            prompts: Text prompt(s), or None.
            prompt_token_ids: Token ID prompt(s), or None.

        Returns:
            A list of ``TextPrompt`` (when ``prompts`` is given) or
            ``TokensPrompt`` dicts, one per request.

        Raises:
            ValueError: If both inputs are given with different batch sizes,
                or if neither input is given.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        if prompts is not None:
            return [TextPrompt(prompt=text) for text in prompts]
        if prompt_token_ids is not None:
            return [
                TokensPrompt(prompt_token_ids=ids) for ids in prompt_token_ids
            ]
        raise ValueError("Either prompts or prompt_token_ids must be "
                         "provided.")
991
992
993

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate batch arguments and submit one engine request per prompt.

        A single prompt (``str`` or ``dict``) is treated as a batch of one.
        ``params`` and ``lora_request`` may be a single value shared by all
        prompts, or a sequence with exactly one entry per prompt. ``priority``,
        when non-empty, is indexed per prompt; otherwise priority 0 is used.

        Every ``SamplingParams`` is mutated in place: the deprecated
        ``guided_options`` are merged into it and its ``output_kind`` is set to
        ``RequestOutputKind.FINAL_ONLY``.

        Raises:
            ValueError: If ``params`` or ``lora_request`` is a sequence whose
                length differs from the number of prompts, or if guided
                decoding is specified both ways.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Use the same ``Sequence`` test for validation and setup as for the
        # per-request indexing below, so tuples are handled like lists.
        params_is_seq = isinstance(params, Sequence)
        lora_is_seq = isinstance(lora_request, Sequence)

        if params_is_seq and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if lora_is_seq and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if params_is_seq else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if params_is_seq else params,
                lora_request=lora_request[i] if lora_is_seq else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )
1040

1041
    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine under a fresh request ID.

        The ID is the string form of the next value drawn from
        ``self.request_counter``. ``_run_engine`` sorts finished outputs by
        ``int(request_id)``, so these IDs are expected to stay numeric.

        Args:
            prompt: The prompt to submit.
            params: Sampling or pooling parameters for this request.
            lora_request: LoRA request for this request, if any.
            prompt_adapter_request: Prompt Adapter request, if any.
            priority: Scheduling priority forwarded to the engine.
        """
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )
1058

1059
    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None
    ) -> SamplingParams:
        """Merge the deprecated ``guided_options`` into
        ``params.guided_decoding``.

        ``params`` is mutated in place and also returned. When
        ``guided_options`` is None, ``params`` is returned unchanged.

        Raises:
            ValueError: If ``params.guided_decoding`` is already set while
                ``guided_options`` is also given.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            # Note the trailing space: the two literals are concatenated.
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)
        return params

1080
    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to show a tqdm progress bar. For
                ``RequestOutput`` results, the bar also shows estimated
                input/output token throughput.

        Returns:
            All finished outputs, sorted by numeric request ID.
        """
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        if isinstance(output, RequestOutput):
                            # Calculate tokens only for RequestOutput
                            assert output.prompt_token_ids is not None
                            total_in_toks += len(output.prompt_token_ids)
                            total_out_toks += sum(
                                len(stp.token_ids) for stp in output.outputs)
                            # tqdm's elapsed time can be 0.0 with a coarse
                            # clock; skip the estimate rather than divide
                            # by zero.
                            elapsed = pbar.format_dict["elapsed"]
                            if elapsed > 0:
                                in_spd = total_in_toks / elapsed
                                out_spd = total_out_toks / elapsed
                                pbar.postfix = (
                                    f"est. speed input: {in_spd:.2f} toks/s, "
                                    f"output: {out_spd:.2f} toks/s")
                        pbar.update(1)

        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))