llm.py 53.9 KB
Newer Older
1
import itertools
2
import warnings
3
from contextlib import contextmanager
Joe Runde's avatar
Joe Runde committed
4
from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type,
5
                    Union, cast, overload)
6

7
from tqdm import tqdm
8
from typing_extensions import deprecated
9

10
from vllm import envs
11
12
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
13
from vllm.config import CompilationConfig
14
15
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
16
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
17
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
18
                                         ChatTemplateContentFormatOption,
19
20
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
21
22
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
23
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
24
from vllm.inputs.parse import parse_and_batch_prompt
25
from vllm.logger import init_logger
26
from vllm.lora.request import LoRARequest
27
28
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
29
30
31
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
32
from vllm.pooling_params import PoolingParams
33
from vllm.prompt_adapter.request import PromptAdapterRequest
34
35
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
36
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
37
38
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
39
from vllm.usage.usage_lib import UsageContext
40
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
41

42
43
# Logger for this module, named after the module via `__name__`.
logger = init_logger(__name__)

44
45

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
46
47
48
49
50
51
52
53
54
55
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
56
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
57
58
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
59
60
61
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
62
63
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
64
65
66
67
        allowed_local_media_path: Allowing API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
Woosuk Kwon's avatar
Woosuk Kwon committed
68
69
70
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
Woosuk Kwon's avatar
Woosuk Kwon committed
71
72
73
74
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
75
        quantization: The method used to quantize the model weights. Currently,
76
            we support "awq", "gptq", and "fp8" (experimental).
77
78
79
80
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
Jasmond L's avatar
Jasmond L committed
81
82
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
83
84
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
85
86
87
88
89
90
91
92
93
94
95
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
96
97
98
99
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
100
101
102
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
103
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
104
            When a sequence has context length larger than this, we fall back
105
106
107
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
108
109
110
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
111
112
113
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
114
115
116
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
117
118
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`)
nunjunj's avatar
nunjunj committed
119

120
121
122
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
123
    """
124

125
    DEPRECATE_LEGACY: ClassVar[bool] = True
126
127
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

128
129
130
131
132
133
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True
    """
    A flag to toggle whether to deprecate positional arguments in
    :meth:`LLM.__init__`.
    """

134
135
136
137
138
139
140
141
142
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily force deprecation of the legacy generate/encode API.

        Sets :attr:`DEPRECATE_LEGACY` to ``True`` for the duration of the
        ``with`` block. On exit, the value it had before entry is restored.
        This also happens when the body raises, so the flag never leaks
        outside the block.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore rather than hard-code False: the class default is True.
            cls.DEPRECATE_LEGACY = previous

143
144
145
146
147
148
149
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, Dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        '''
        LLM constructor.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.

        ``compilation_config`` may be an integer, a dictionary, or an
        already-built :class:`~vllm.config.CompilationConfig`, which is
        passed through unchanged. All arguments are collected into an
        :class:`~vllm.EngineArgs` that is used to build the engine.
        '''
        # Default to disable_log_stats=True unless the caller passed it.
        kwargs.setdefault("disable_log_stats", True)

        compilation_config_instance: Optional[CompilationConfig]
        if isinstance(compilation_config, (int, dict)):
            # `from_cli` parses the string form, so stringify ints/dicts.
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            # Either None or an existing CompilationConfig instance.
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Logic to switch between engines is done at runtime instead of import
        # to avoid import order issues
        self.engine_class = self.get_engine_class()

        # TODO(rob): enable mp by default (issue with fork vs spawn)
        self.llm_engine = self.engine_class.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

        self.request_counter = Counter()

235
236
237
238
    def __del__(self):
        """Shut down the engine when this object is finalized.

        ``llm_engine`` is missing if ``__init__`` raised before assigning it.
        It is therefore looked up with a default, so finalizing a
        half-constructed object does not raise ``AttributeError``.
        """
        engine = getattr(self, "llm_engine", None)
        if engine and hasattr(engine, "shutdown"):
            engine.shutdown()

Joe Runde's avatar
Joe Runde committed
239
240
241
242
243
244
245
246
    @staticmethod
    def get_engine_class() -> Type[LLMEngine]:
        """Return the engine class selected by ``envs.VLLM_USE_V1``.

        The flag is read at call time, so the V1 engine module is only
        imported when it is actually requested.
        """
        if not envs.VLLM_USE_V1:
            return LLMEngine
        # Lazy import: the v1 package isn't distributed
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        return V1LLMEngine  # type: ignore

247
248
249
250
251
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` on the engine's tokenizer group.

        A tokenizer that does not already look cached is wrapped with
        :func:`get_cached_tokenizer` before being stored.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # Cached tokenizer classes are created dynamically, so the class-name
        # prefix is the only available check. A user-defined tokenizer whose
        # class name starts with 'Cached' will be misjudged as cached.
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer
                           if looks_cached else get_cached_tokenizer(tokenizer))
260

261
262
263
264
265
266
267
    # Type checkers only see these overload signatures, not the
    # implementation's. Each one therefore lists every parameter the
    # implementation accepts, including ``priority``.
    @overload  # Preferred: prompts passed positionally
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
358
359
360
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: (Legacy) Token IDs of the prompts. Prefer
                passing :class:`~vllm.inputs.TokensPrompt` objects via
                ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options for the requests,
                if any. A dict may specify at most one guided-decoding option
                and is converted to a ``GuidedDecodingRequest``.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the ``generate``
                runner, or if ``guided_options_request`` is a dict with more
                than one entry.

        Note:
            Passing ``prompts`` as a keyword argument, or passing
            ``prompt_token_ids`` at all, is considered legacy and may be
            deprecated in the future. Instead, pass the prompts as the first
            positional argument, using :class:`~vllm.inputs.TokensPrompt`
            objects for token IDs.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy call form: build prompts from the separate
            # `prompts` / `prompt_token_ids` arguments.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            # The dict form may name only a single guided-decoding option.
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
456

457
458
459
    def beam_search(
        self,
        prompts: List[Union[str, List[int]]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt can be a string or a list
                of token IDs.
            params: The beam search parameters.

        Returns:
            One :class:`BeamSearchOutput` per prompt, in input order. Each
            holds at most ``params.beam_width`` sequences, drawn from both
            finished and still-open beams. They are sorted best-first by
            :func:`get_beam_search_score`, and each has ``text`` decoded
            from its tokens.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Resolved before `sort_beams_key` is defined, since it uses them.
        tokenizer = self.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            # Token-id lists are used directly; strings are tokenized.
            prompt_tokens = prompt if isinstance(
                prompt, list) else tokenizer.encode(prompt)
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten every instance's live beams into one batch;
            # chain.from_iterable keeps this linear in the number of beams.
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                # No live beams left in any instance.
                break

            # all_beams[start:end] are the beams of the matching instance.
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            # Pool finished and still-open beams, then keep the best ones.
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
565
566
    def chat(
        self,
567
568
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
nunjunj's avatar
nunjunj committed
569
570
571
572
573
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
574
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
575
        add_generation_prompt: bool = True,
576
        continue_final_message: bool = False,
577
        tools: Optional[List[Dict[str, Any]]] = None,
578
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
nunjunj's avatar
nunjunj committed
579
580
    ) -> List[RequestOutput]:
        """
581
        Generate responses for a chat conversation.
nunjunj's avatar
nunjunj committed
582

583
584
585
586
587
588
        The chat conversation is converted into a text prompt using the
        tokenizer and calls the :meth:`generate` method to generate the
        responses.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.
nunjunj's avatar
nunjunj committed
589
590

        Args:
591
592
593
594
595
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

nunjunj's avatar
nunjunj committed
596
597
598
599
600
601
602
603
604
            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
605
606
607
608
609
610
611
612
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

613
            add_generation_prompt: If True, adds a generation template
nunjunj's avatar
nunjunj committed
614
                to each message.
615
            continue_final_message: If True, continues the final message in
616
617
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
618
619
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests.
nunjunj's avatar
nunjunj committed
620
621
622
623
624

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
625
        list_of_messages: List[List[ChatCompletionMessageParam]]
nunjunj's avatar
nunjunj committed
626

627
628
629
        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
630
631
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
632
        else:
633
            # messages is List[...]
634
635
636
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]
637

638
639
640
641
642
643
644
645
        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )

646
647
648
        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
649
650
651
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
652
            conversation, mm_data = parse_chat_messages(
653
654
655
656
657
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )
658
659
660
661
662
663
664
665

            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
666
                    continue_final_message=continue_final_message,
667
668
669
670
671
672
673
674
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
675
                    continue_final_message=continue_final_message,
676
677
678
679
680
681
682
683
684
685
686
687
                    tools=tools,
                )

            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

688
689
690
            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

691
            prompts.append(prompt)
692

nunjunj's avatar
nunjunj committed
693
        return self.generate(
694
            prompts,
695
            sampling_params=sampling_params,
nunjunj's avatar
nunjunj committed
696
697
698
699
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

700
701
702
703
704
705
706
    # Preferred call form: ``prompts`` is positional-only, ``pooling_params``
    # may be passed positionally or by keyword, and the remaining options are
    # keyword-only. The ``@deprecated`` overloads below describe the legacy
    # ``prompt_token_ids`` call shapes that are still accepted.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. A list must contain one
                entry per prompt.
            prompt_token_ids: LEGACY. Token IDs of a single prompt or of a
                batch of prompts. Prefer passing token IDs through
                ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any. A list
                must contain one entry per prompt.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the engine is not running the pooling runner, or
                if per-prompt argument lists do not match the number of
                prompts.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and may be
            deprecated in the future. You should instead pass the token IDs
            via the ``prompts`` parameter (e.g. as a
            :class:`~vllm.inputs.TokensPrompt`).
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                # The model could pool, it was just started with another task.
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: build prompt objects from text/token-id arguments.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)

    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per input prompt.

        Prompts are batched automatically subject to memory limits, so
        handing every prompt to a single call gives the best throughput.

        Args:
            prompts: One prompt, or a sequence of prompts for batch
                inference. See :class:`~vllm.inputs.PromptType` for the
                accepted formats.
            use_tqdm: Show a tqdm progress bar while running.
            lora_request: Optional LoRA request(s) applied to the prompts.
            prompt_adapter_request: Optional prompt-adapter request applied
                to the prompts.

        Returns:
            ``EmbeddingRequestOutput`` objects, one per prompt, in the same
            order as ``prompts``.

        Raises:
            ValueError: If the model was not started with ``--task embed``.
        """
        configured_task = self.llm_engine.model_config.task
        if configured_task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        # Wrap each generic pooling result in the embedding-specific type.
        return list(map(EmbeddingRequestOutput.from_base, pooled))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification results for each prompt, in the same order as the
            input prompts.

        Raises:
            ValueError: If the model was not started with ``--task classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        # Convert the generic pooling outputs into classification outputs.
        return [ClassificationRequestOutput.from_base(item) for item in items]

    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated ``N``
        times to pair with the ``text_2`` sentences.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This method automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompts.
            truncate_prompt_tokens: If set, passed to the tokenizer as
                ``max_length`` together with ``truncation=True`` when each
                ``(text_1, text_2)`` pair is tokenized.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a cross-encoding pooling model
                started with ``--task score``, if the tokenizer is a
                ``MistralTokenizer``, if a prompt is multi-modal or not
                convertible to text, or if the input lengths are invalid.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                # The model could pool, it was just started with another task.
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if not self.llm_engine.model_config.is_cross_encoder:
            raise ValueError("Your model does not support cross encoding")
        if self.llm_engine.model_config.task != "score":
            raise ValueError("Score API is only enabled for `--task score`")

        tokenizer = self.llm_engine.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        def ensure_str(prompt: SingletonPrompt) -> str:
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for cross encoding")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Validate explicitly rather than with ``assert`` so the check
            # still runs under ``python -O``.
            if not isinstance(prompt, str):
                raise ValueError(
                    "Cross encoding expects text or token prompts, got "
                    f"{type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        text_1 = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        text_2 = [ensure_str(t) for t in text_2]

        if len(text_1) > 1 and len(text_1) != len(text_2):
            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
        if len(text_1) == 0:
            raise ValueError("At least one text element must be given")
        if len(text_2) == 0:
            raise ValueError("At least one text_pair element must be given")

        if len(text_1) == 1:
            # 1 -> N: replicate the single query for every text_2 entry.
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))
        pooling_params = PoolingParams()

        tokenization_kwargs: Dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

    def start_profile(self) -> None:
        """Forward a start-profiling request to the wrapped engine."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Forward a stop-profiling request to the wrapped engine."""
        engine = self.llm_engine
        engine.stop_profile()

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ) -> List[PromptType]:
        """Convert legacy ``prompts`` / ``prompt_token_ids`` arguments into
        a list of prompt objects.

        When ``prompts`` is given, each entry becomes a ``TextPrompt`` (any
        ``prompt_token_ids`` are then only used for the length check).
        Otherwise each ``prompt_token_ids`` entry becomes a ``TokensPrompt``.

        Raises:
            ValueError: If both are given with different lengths, or if
                neither is given.
        """
        # skip_tokenizer_init is now checked in engine

        # Normalize single/batched inputs into flat per-request lists.
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")

        # Text prompts take priority, matching the historical behavior.
        if prompts is not None:
            return [TextPrompt(prompt=text) for text in prompts]
        if prompt_token_ids is not None:
            return [
                TokensPrompt(prompt_token_ids=token_ids)
                for token_ids in prompt_token_ids
            ]

        raise ValueError("Either prompts or prompt_token_ids must be "
                         "provided.")

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate per-request arguments and enqueue one engine request per
        prompt.

        Args:
            prompts: A single prompt or a sequence of prompts.
            params: Sampling or pooling parameters: one value shared by all
                prompts, or a sequence with one entry per prompt.
            lora_request: One LoRA request shared by all prompts, a sequence
                with one entry per prompt, or None.
            prompt_adapter_request: Prompt adapter request passed with every
                prompt.
            guided_options: Deprecated guided-decoding options; merged into
                every ``SamplingParams`` via ``_add_guided_params``.
            priority: Optional per-prompt priorities; ``0`` is used when this
                is None or empty.

        Raises:
            ValueError: If ``params`` or ``lora_request`` is a sequence whose
                length differs from the number of prompts.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # Treat any Sequence (list, tuple, ...) as per-prompt values, using the
        # same test as the indexing in the add loop below. Previously only
        # ``list`` was validated/configured, so a tuple skipped both steps.
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )

    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine under a fresh request id.

        Request ids are taken from ``self.request_counter`` and stringified;
        ``_run_engine`` restores submission order by sorting finished outputs
        on ``int(request_id)``.
        """
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None,
    ) -> SamplingParams:
        """Fold deprecated ``guided_options`` into ``params`` in place.

        Args:
            params: Sampling parameters to update.
            guided_options: Legacy guided-decoding request; when None,
                ``params`` is returned untouched.

        Returns:
            ``params`` itself (with ``guided_decoding`` set when
            ``guided_options`` was given).

        Raises:
            ValueError: If ``params.guided_decoding`` is already set and
                ``guided_options`` is also provided.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until every queued request has finished.

        Args:
            use_tqdm: Whether to display a progress bar. For
                ``RequestOutput`` results the bar also shows estimated input
                and output token throughput.

        Returns:
            The finished outputs sorted by numeric request id, i.e. in the
            order the requests were submitted.
        """
        # Initialize tqdm.
        pbar: Optional[tqdm] = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        elapsed = pbar.format_dict["elapsed"]
                        # A low-resolution clock can report zero elapsed time
                        # for a very fast first step; skip the estimate
                        # instead of dividing by zero.
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even if an engine step raises.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))