llm.py 57.3 KB
Newer Older
1
import itertools
2
import warnings
3
from contextlib import contextmanager
4
5
from typing import (Any, Callable, ClassVar, Dict, List, Optional, Sequence,
                    Tuple, Type, Union, cast, overload)
6

7
import cloudpickle
8
import torch.nn as nn
9
from tqdm import tqdm
10
from typing_extensions import TypeVar, deprecated
11

12
from vllm import envs
13
14
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                              BeamSearchSequence, get_beam_search_score)
15
from vllm.config import CompilationConfig
16
17
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                   TaskOption)
Joe Runde's avatar
Joe Runde committed
18
from vllm.engine.llm_engine import LLMEngine
nunjunj's avatar
nunjunj committed
19
from vllm.entrypoints.chat_utils import (ChatCompletionMessageParam,
20
                                         ChatTemplateContentFormatOption,
21
22
                                         apply_hf_chat_template,
                                         apply_mistral_chat_template,
23
24
                                         parse_chat_messages,
                                         resolve_chat_template_content_format)
25
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
26
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
27
from vllm.logger import init_logger
28
from vllm.lora.request import LoRARequest
29
30
from vllm.model_executor.guided_decoding.guided_fields import (
    GuidedDecodingRequest, LLMGuidedOptions)
31
32
33
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                          PoolingRequestOutput, RequestOutput,
                          ScoringRequestOutput)
34
from vllm.pooling_params import PoolingParams
35
from vllm.prompt_adapter.request import PromptAdapterRequest
36
37
from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams,
                                  RequestOutputKind, SamplingParams)
38
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
39
40
                                               get_cached_tokenizer)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
yhu422's avatar
yhu422 committed
41
from vllm.usage.usage_lib import UsageContext
42
from vllm.utils import Counter, deprecate_args, deprecate_kwargs, is_list_of
43

44
45
logger = init_logger(__name__)

# Return type of callables executed on the workers through
# ``LLM.collective_rpc`` and ``LLM.apply_model``. ``default=Any`` relies on
# the ``typing_extensions.TypeVar`` extension for type-parameter defaults.
_R = TypeVar("_R", default=Any)

48
49

class LLM:
Woosuk Kwon's avatar
Woosuk Kwon committed
50
51
52
53
54
55
56
57
58
59
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
60
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
61
62
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
63
64
65
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
66
67
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
68
69
70
71
        allowed_local_media_path: Allow API requests to read local images
            or videos from directories specified by the server file system.
            This is a security risk. Should only be enabled in trusted
            environments.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
100
101
102
103
        cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
            the model weights. This virtually increases the GPU memory space
            you can use to hold the model weights, at the cost of CPU-GPU data
            transfer for every forward pass.
104
105
106
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
107
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
108
            When a sequence has context length larger than this, we fall back
109
110
111
            to eager mode. Additionally for encoder-decoder models, if the
            sequence length of the encoder input is larger than this, we fall
            back to the eager mode.
112
113
114
        disable_custom_all_reduce: See :class:`~vllm.config.ParallelConfig`
        disable_async_output_proc: Disable async output processing.
            This may result in lower performance.
115
116
117
        hf_overrides: If a dictionary, contains arguments to be forwarded to the
            HuggingFace config. If a callable, it is called to update the
            HuggingFace config.
118
119
120
        compilation_config: Either an integer or a dictionary. If it is an
            integer, it is used as the level of compilation optimization. If it
            is a dictionary, it can specify the full compilation configuration.
121
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine-args`)

124
125
126
    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
127
    """
128

129
    #: A flag to toggle whether to deprecate the legacy generate/encode API.
    #: Read lazily by the ``is_deprecated`` callback of the
    #: ``deprecate_kwargs`` decorator on :meth:`generate`.
    DEPRECATE_LEGACY: ClassVar[bool] = True

    #: A flag to toggle whether to deprecate positional arguments in
    #: :meth:`LLM.__init__`. Read lazily by the ``is_deprecated`` callback of
    #: the ``deprecate_args`` decorator on :meth:`__init__`.
    DEPRECATE_INIT_POSARGS: ClassVar[bool] = True

138
139
140
141
142
143
144
145
146
    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily mark the legacy generate/encode API as deprecated.

        Inside the ``with`` block ``cls.DEPRECATE_LEGACY`` is ``True``. On
        exit the flag is restored to the value it had on entry, even when
        the block raises, so using this context manager never leaves the
        class in a different state than before.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            # Restore instead of hard-coding False: the class default is
            # True, and an exception in the block must not leave the flag
            # flipped.
            cls.DEPRECATE_LEGACY = previous

147
148
149
150
151
152
153
    @deprecate_args(
        start_index=2,  # Ignore self and model
        is_deprecated=lambda: LLM.DEPRECATE_INIT_POSARGS,
        additional_message=(
            "All positional arguments other than `model` will be "
            "replaced with keyword arguments in an upcoming version."),
    )
    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        allowed_local_media_path: str = "",
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: float = 4,
        cpu_offload_gb: float = 0,
        enforce_eager: Optional[bool] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        disable_async_output_proc: bool = False,
        hf_overrides: Optional[HfOverrides] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
        # After positional args are removed, move this right below `model`
        task: TaskOption = "auto",
        override_pooler_config: Optional[PoolerConfig] = None,
        compilation_config: Optional[Union[int, Dict[str, Any],
                                           CompilationConfig]] = None,
        **kwargs,
    ) -> None:
        '''
        LLM constructor.

        Collects the arguments into an :class:`~vllm.EngineArgs` and builds
        the engine returned by :meth:`get_engine_class` from it.

        ``compilation_config`` may be an int or a dict (parsed through
        ``CompilationConfig.from_cli``) or an already-built
        ``CompilationConfig``, which is passed through unchanged.

        Unless given explicitly, ``disable_log_stats`` defaults to True.
        A ``worker_cls`` given as a class (rather than a qualified name) is
        serialized with cloudpickle before being forwarded.

        Note: if enforce_eager is unset (enforce_eager is None)
        it defaults to False.
        '''
        kwargs.setdefault("disable_log_stats", True)

        if "worker_cls" in kwargs:
            worker_cls = kwargs["worker_cls"]
            # if the worker_cls is not qualified string name,
            # we serialize it using cloudpickle to avoid pickling issues
            if isinstance(worker_cls, type):
                kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

        if compilation_config is None:
            compilation_config_instance = None
        elif isinstance(compilation_config, (int, dict)):
            compilation_config_instance = CompilationConfig.from_cli(
                str(compilation_config))
        else:
            # Anything else (e.g. a ready-made CompilationConfig) is
            # forwarded as-is.
            compilation_config_instance = compilation_config

        engine_args = EngineArgs(
            model=model,
            task=task,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            allowed_local_media_path=allowed_local_media_path,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            cpu_offload_gb=cpu_offload_gb,
            enforce_eager=enforce_eager,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            disable_async_output_proc=disable_async_output_proc,
            hf_overrides=hf_overrides,
            mm_processor_kwargs=mm_processor_kwargs,
            override_pooler_config=override_pooler_config,
            compilation_config=compilation_config_instance,
            **kwargs,
        )

        # Logic to switch between engines is done at runtime instead of import
        # to avoid import order issues
        self.engine_class = self.get_engine_class()
        self.llm_engine = self.engine_class.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)

        self.request_counter = Counter()

Joe Runde's avatar
Joe Runde committed
244
245
246
247
248
249
250
251
    @staticmethod
    def get_engine_class() -> Type[LLMEngine]:
        """Pick the engine class according to ``envs.VLLM_USE_V1``.

        Returns the V0 :class:`LLMEngine` by default. The V1 engine is only
        imported when it is requested.
        """
        if not envs.VLLM_USE_V1:
            return LLMEngine
        # Lazy import: the v1 package isn't distributed
        from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
        return V1LLMEngine  # type: ignore

252
253
254
255
256
    def get_tokenizer(self) -> AnyTokenizer:
        """Return the tokenizer held by the engine's tokenizer group."""
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)
        return group.tokenizer

    def set_tokenizer(self, tokenizer: AnyTokenizer) -> None:
        """Install ``tokenizer`` into the engine's tokenizer group.

        A tokenizer whose class name starts with ``"Cached"`` is stored
        as-is. Any other tokenizer is first wrapped with
        :func:`get_cached_tokenizer`.
        """
        group = self.llm_engine.get_tokenizer_group(TokenizerGroup)

        # While CachedTokenizer is dynamic, have no choice but
        # compare class name. Misjudgment will arise from
        # user-defined tokenizer started with 'Cached'
        looks_cached = tokenizer.__class__.__name__.startswith("Cached")
        group.tokenizer = (tokenizer if looks_cached else
                           get_cached_tokenizer(tokenizer))
265

266
267
268
269
270
271
272
    def get_default_sampling_params(self) -> SamplingParams:
        """Build the default :class:`SamplingParams` for the loaded model.

        When the model config's ``get_diff_sampling_param()`` reports any
        overrides, they are applied via ``SamplingParams.from_optional``.
        Otherwise a plain ``SamplingParams()`` is returned.
        """
        model_config = self.llm_engine.model_config
        overrides = model_config.get_diff_sampling_param()
        if not overrides:
            return SamplingParams()
        return SamplingParams.from_optional(**overrides)

273
274
275
276
277
278
279
    # Typing-only overloads of `generate`; the runtime implementation
    # (decorated with `deprecate_kwargs`) follows them.
    @overload
    def generate(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        # Accepted by the implementation; listing it here lets type checkers
        # allow `generate(prompts, ..., priority=[...])`.
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
    ) -> List[RequestOutput]:
        ...

nunjunj's avatar
nunjunj committed
370
371
372
    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def generate(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        guided_options_request: Optional[Union[LLMGuidedOptions,
                                               GuidedDecodingRequest]] = None,
        priority: Optional[List[int]] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompts.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            prompt_token_ids: Legacy way of passing token IDs (combined with
                ``prompts`` given as plain strings). Deprecated when
                :attr:`DEPRECATE_LEGACY` is set; use
                :class:`~vllm.inputs.TokensPrompt` in ``prompts`` instead.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.
            guided_options_request: Guided decoding options, if any. When
                given as a dict, it may specify at most one option and is
                converted into a ``GuidedDecodingRequest``.
            priority: The priority of the requests, if any.
                Only applicable when priority scheduling policy is enabled.

        Returns:
            A list of ``RequestOutput`` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If the model is not initialized for the 'generate'
                runner, or if ``guided_options_request`` is a dict that
                specifies more than one option.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass ``prompts`` positionally, supplying token IDs through
            :class:`~vllm.inputs.TokensPrompt` entries.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "generate":
            messages = [
                "LLM.generate() is only supported for (conditional) generation "
                "models (XForCausalLM, XForConditionalGeneration).",
            ]

            supported_runner_types = (
                self.llm_engine.model_config.supported_runner_types)
            if "generate" in supported_runner_types:
                messages.append(
                    "Your model supports the 'generate' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task generate`.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            # Legacy path: merge plain-string prompts with token IDs.
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if isinstance(guided_options_request, dict):
            if len(guided_options_request) > 1:
                raise ValueError(
                    "You can only use one guided decoding but multiple is "
                    f"specified: {guided_options_request}")
            guided_options_request = GuidedDecodingRequest(
                **guided_options_request)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = self.get_default_sampling_params()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=sampling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            guided_options=guided_options_request,
            priority=priority)

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs, RequestOutput)
468

469
    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: Tuple = (),
                       kwargs: Optional[Dict[str, Any]] = None) -> List[_R]:
        """
        Run ``method`` on every worker and collect the per-worker results.

        Args:
            method: Either the name of a worker method, or a callable that is
                serialized and executed on each worker. A callable receives
                the worker object as an extra leading ``self`` argument,
                followed by ``args`` and ``kwargs``.
            timeout: Maximum time in seconds to wait for execution. Raises a
                :exc:`TimeoutError` on timeout. `None` means wait indefinitely.
            args: Positional arguments to pass to the worker method.
            kwargs: Keyword arguments to pass to the worker method.

        Returns:
            A list containing the results from each worker.

        Note:
            It is recommended to use this API to only pass control messages,
            and set up data-plane communication to pass data.
        """
        return self.llm_engine.model_executor.collective_rpc(
            method, timeout, args, kwargs)

    def apply_model(self, func: Callable[[nn.Module], _R]) -> List[_R]:
        """
        Run a function directly on the model inside each worker,
        returning the result for each of them.

        Args:
            func: Callable that receives the worker's model
                (an ``nn.Module``) and returns a value.

        Returns:
            The list of results produced by the model executor's
            ``apply_model``, one per worker.
        """
        executor = self.llm_engine.model_executor
        return executor.apply_model(func)
506

507
508
    def beam_search(
        self,
        prompts: List[Union[TokensPrompt, TextPrompt]],
        params: BeamSearchParams,
    ) -> List[BeamSearchOutput]:
        """
        Generate sequences using beam search.

        Args:
            prompts: A list of prompts. Each prompt is either a
                :class:`~vllm.inputs.TokensPrompt` (its ``prompt_token_ids``
                are used directly) or a :class:`~vllm.inputs.TextPrompt`
                (its ``prompt`` text is encoded with the tokenizer).
            params: The beam search parameters.

        Returns:
            One ``BeamSearchOutput`` per prompt, in input order. Each holds
            at most ``params.beam_width`` sequences, sorted best-first by
            ``get_beam_search_score``, with ``text`` set to the decoded
            tokens.

        TODO: how does beam search work together with length penalty, frequency
        penalty, and stopping criteria, etc.?
        """
        beam_width = params.beam_width
        max_tokens = params.max_tokens
        temperature = params.temperature
        ignore_eos = params.ignore_eos
        length_penalty = params.length_penalty

        # Resolve the tokenizer (and its EOS id) before the scoring key that
        # uses them is defined; the EOS id is loop-invariant.
        tokenizer = self.get_tokenizer()
        eos_token_id = tokenizer.eos_token_id

        def sort_beams_key(x: BeamSearchSequence) -> float:
            return get_beam_search_score(x.tokens, x.cum_logprob,
                                         eos_token_id, length_penalty)

        # generate 2 * beam_width candidates at each step
        # following the huggingface transformers implementation
        # at https://github.com/huggingface/transformers/blob/e15687fffe5c9d20598a19aeab721ae0a7580f8a/src/transformers/generation/beam_search.py#L534 # noqa
        beam_search_params = SamplingParams(logprobs=2 * beam_width,
                                            max_tokens=1,
                                            temperature=temperature)
        instances: List[BeamSearchInstance] = []

        for prompt in prompts:
            if is_token_prompt(prompt):
                prompt_tokens = prompt["prompt_token_ids"]
            else:
                prompt_tokens = tokenizer.encode(prompt["prompt"])
            instances.append(BeamSearchInstance(prompt_tokens))

        for _ in range(max_tokens):
            # Flatten the live beams of all instances in linear time
            # (`sum(lists, [])` would re-copy the accumulator at every step).
            all_beams: List[BeamSearchSequence] = list(
                itertools.chain.from_iterable(
                    instance.beams for instance in instances))
            if not all_beams:
                break

            # instance_start_and_end[k] is the slice of `all_beams` (and of
            # the generate output) that belongs to instances[k].
            pos = [0] + list(
                itertools.accumulate(
                    len(instance.beams) for instance in instances))
            instance_start_and_end: List[Tuple[int, int]] = list(
                zip(pos[:-1], pos[1:]))

            prompts_batch = [
                TokensPrompt(prompt_token_ids=beam.tokens)
                for beam in all_beams
            ]

            # only runs for one step
            # we don't need to use tqdm here
            output = self.generate(prompts_batch,
                                   sampling_params=beam_search_params,
                                   use_tqdm=False)

            for (start, end), instance in zip(instance_start_and_end,
                                              instances):
                instance_new_beams = []
                for i in range(start, end):
                    current_beam = all_beams[i]
                    result = output[i]

                    if result.outputs[0].logprobs is not None:
                        # if `result.outputs[0].logprobs` is None, it means
                        # the sequence is completed because of the max-model-len
                        # or abortion. we don't need to add it to the new beams.
                        logprobs = result.outputs[0].logprobs[0]
                        for token_id, logprob_obj in logprobs.items():
                            new_beam = BeamSearchSequence(
                                tokens=current_beam.tokens + [token_id],
                                logprobs=current_beam.logprobs + [logprobs],
                                cum_logprob=current_beam.cum_logprob +
                                logprob_obj.logprob)

                            if token_id == eos_token_id and not ignore_eos:
                                instance.completed.append(new_beam)
                            else:
                                instance_new_beams.append(new_beam)
                sorted_beams = sorted(instance_new_beams,
                                      key=sort_beams_key,
                                      reverse=True)
                instance.beams = sorted_beams[:beam_width]

        outputs = []
        for instance in instances:
            instance.completed.extend(instance.beams)
            sorted_completed = sorted(instance.completed,
                                      key=sort_beams_key,
                                      reverse=True)
            best_beams = sorted_completed[:beam_width]

            for beam in best_beams:
                beam.text = tokenizer.decode(beam.tokens)
            outputs.append(BeamSearchOutput(sequences=best_beams))

        return outputs

nunjunj's avatar
nunjunj committed
617
618
    def chat(
        self,
        messages: Union[List[ChatCompletionMessageParam],
                        List[List[ChatCompletionMessageParam]]],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        chat_template: Optional[str] = None,
        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
        add_generation_prompt: bool = True,
        continue_final_message: bool = False,
        tools: Optional[List[Dict[str, Any]]] = None,
        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
    ) -> List[RequestOutput]:
        """
        Generate responses for a chat conversation.

        Each conversation is rendered into a prompt with the tokenizer's chat
        template (or the Mistral renderer for Mistral tokenizers), and the
        resulting prompts are passed to :meth:`generate`.

        Multi-modal inputs can be passed in the same way you would pass them
        to the OpenAI API.

        Args:
            messages: A list of conversations or a single conversation.

              - Each conversation is represented as a list of messages.
              - Each message is a dictionary with 'role' and 'content' keys.

            sampling_params: The sampling parameters for text generation.
                If None, we use the default sampling parameters. When it
                is a single value, it is applied to every prompt. When it
                is a list, the list must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            chat_template: The template to use for structuring the chat.
              If not provided, the model's default chat template will be used.
            chat_template_content_format: The format to render message content.

              - "string" will render the content as a string.
                Example: ``"Who are you?"``
              - "openai" will render the content as a list of dictionaries,
                similar to OpenAI schema.
                Example: ``[{"type": "text", "text": "Who are you?"}]``

              The default, ``"auto"``, defers the choice to
              ``resolve_chat_template_content_format``, which is given the
              chat template and the tokenizer.
            add_generation_prompt: If True, the chat template is asked to
                append the generation prompt (the header that opens a new
                assistant turn) once, after the last message of each
                conversation.
            continue_final_message: If True, continues the final message in
                the conversation instead of starting a new one. Cannot be
                ``True`` if ``add_generation_prompt`` is also ``True``.
            tools: Optional list of tool definitions (dicts) forwarded to the
                chat template renderer.
            mm_processor_kwargs: Multimodal processor kwarg overrides for this
                chat request. Only used for offline requests. When given, the
                same overrides are attached to every prompt.

        Returns:
            A list of ``RequestOutput`` objects containing the generated
            responses in the same order as the input messages.
        """
        list_of_messages: List[List[ChatCompletionMessageParam]]

        # Handle multi and single conversations
        if is_list_of(messages, list):
            # messages is List[List[...]]
            list_of_messages = cast(List[List[ChatCompletionMessageParam]],
                                    messages)
        else:
            # messages is List[...]
            list_of_messages = [
                cast(List[ChatCompletionMessageParam], messages)
            ]

        tokenizer = self.get_tokenizer()
        model_config = self.llm_engine.get_model_config()
        # The content format depends only on the template and tokenizer, so
        # it is resolved once for all conversations.
        resolved_content_format = resolve_chat_template_content_format(
            chat_template,
            chat_template_content_format,
            tokenizer,
        )

        prompts: List[Union[TokensPrompt, TextPrompt]] = []

        for msgs in list_of_messages:
            # NOTE: _parse_chat_message_content_parts() currently doesn't
            # handle mm_processor_kwargs, since there is no implementation in
            # the chat message parsing for it.
            conversation, mm_data = parse_chat_messages(
                msgs,
                model_config,
                tokenizer,
                content_format=resolved_content_format,
            )

            # Mistral tokenizers render from the raw messages, HF tokenizers
            # from the parsed conversation.
            prompt_data: Union[str, List[int]]
            if isinstance(tokenizer, MistralTokenizer):
                prompt_data = apply_mistral_chat_template(
                    tokenizer,
                    messages=msgs,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )
            else:
                prompt_data = apply_hf_chat_template(
                    tokenizer,
                    conversation=conversation,
                    chat_template=chat_template,
                    add_generation_prompt=add_generation_prompt,
                    continue_final_message=continue_final_message,
                    tools=tools,
                )

            # The renderer may return either token ids or text.
            prompt: Union[TokensPrompt, TextPrompt]
            if is_list_of(prompt_data, int):
                prompt = TokensPrompt(prompt_token_ids=prompt_data)
            else:
                prompt = TextPrompt(prompt=prompt_data)

            if mm_data is not None:
                prompt["multi_modal_data"] = mm_data

            if mm_processor_kwargs is not None:
                prompt["mm_processor_kwargs"] = mm_processor_kwargs

            prompts.append(prompt)

        return self.generate(
            prompts,
            sampling_params=sampling_params,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
        )

    # Preferred form: ``prompts`` is positional-only and may be a single
    # prompt or a sequence of prompts; the remaining options are keyword-only.
    @overload
    def encode(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    # Legacy: one text prompt, optionally with its token ids.
    @overload  # LEGACY: single (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    # Legacy: a batch of text prompts, optionally with their token ids.
    @overload  # LEGACY: multi (prompt + optional token ids)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    # Legacy: one list of token ids (keyword-only), text prompt optional.
    @overload  # LEGACY: single (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    # Legacy: a batch of token-id lists (keyword-only), text optional.
    @overload  # LEGACY: multi (token ids + optional prompt)
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    # Legacy: token ids passed positionally with explicit ``None`` for
    # ``prompts`` and ``pooling_params``.
    @overload  # LEGACY: single or multi token ids [pos-only]
    @deprecated("'prompt_token_ids' will become part of 'prompts'")
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        ...

    @deprecate_kwargs(
        "prompt_token_ids",
        is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
        additional_message="Please use the 'prompts' parameter instead.",
    )
    def encode(
        self,
        prompts: Union[Union[PromptType, Sequence[PromptType]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[PoolingRequestOutput]:
        """Apply pooling to the hidden states corresponding to the input
        prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. A sequence is paired one
                by one with the prompts and must have the same length.
            prompt_token_ids: LEGACY. Token ids for the prompts; when given,
                ``prompts`` and ``prompt_token_ids`` are converted with
                ``_convert_v1_inputs``.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``PoolingRequestOutput`` objects containing the
            pooled hidden states in the same order as the input prompts.

        Raises:
            ValueError: If the model is not running with the ``pooling``
                runner.

        Note:
            Passing ``prompt_token_ids`` is considered legacy and may be
            removed in the future. You should instead pass token ids through
            the ``prompts`` parameter, e.g. as
            :class:`~vllm.inputs.TokensPrompt` dicts.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.encode() is only supported for pooling models."]

            # Give a more actionable hint when the model could have been
            # started as a pooling model.
            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if prompt_token_ids is not None:
            parsed_prompts = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
            )
        else:
            parsed_prompts = cast(Union[PromptType, Sequence[PromptType]],
                                  prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return self.engine_class.validate_outputs(outputs,
                                                  PoolingRequestOutput)

    def embed(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        """
        Compute one embedding vector per prompt.

        Prompts are batched automatically within the engine's memory
        constraint; passing all of them in a single call to this method gives
        the best performance.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``EmbeddingRequestOutput`` objects, one per prompt and
            in the same order as the input prompts, built from the pooled
            outputs of :meth:`encode`.

        Raises:
            ValueError: If the model was not initialized with ``--task embed``.
        """
        model_task = self.llm_engine.model_config.task
        if model_task != "embed":
            raise ValueError(
                "Embedding API is only enabled for `--task embed`")

        pooled_outputs = self.encode(
            prompts,
            use_tqdm=use_tqdm,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )
        return list(map(EmbeddingRequestOutput.from_base, pooled_outputs))

    def classify(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        /,
        *,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ClassificationRequestOutput]:
        """
        Generate class logits for each prompt.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your
        prompts into a single list and pass it to this method.

        Args:
            prompts: The prompts to the LLM. You may pass a sequence of prompts
                for batch inference. See :class:`~vllm.inputs.PromptType`
                for more details about the format of each prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ClassificationRequestOutput`` objects containing the
            classification outputs, one per prompt and in the same order as
            the input prompts, built from the pooled outputs of
            :meth:`encode`.

        Raises:
            ValueError: If the model was not initialized with
                ``--task classify``.
        """
        if self.llm_engine.model_config.task != "classify":
            raise ValueError(
                "Classification API is only enabled for `--task classify`")

        items = self.encode(prompts,
                            use_tqdm=use_tqdm,
                            lora_request=lora_request,
                            prompt_adapter_request=prompt_adapter_request)

        return [ClassificationRequestOutput.from_base(item) for item in items]

    def score(
        self,
        text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]],
        /,
        *,
        truncate_prompt_tokens: Optional[int] = None,
        use_tqdm: bool = True,
        lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> List[ScoringRequestOutput]:
        """Generate similarity scores for all pairs ``<text,text_pair>``.

        The inputs can be ``1 -> 1``, ``1 -> N`` or ``N -> N``.
        In the ``1 -> N`` case the ``text_1`` sentence will be replicated
        ``N`` times to pair with the ``text_2`` sentences.
        The input pairs are used to build a list of prompts for the
        cross encoder model. This method automatically batches the prompts,
        considering the memory constraint. For the best performance, put all
        of your texts into a single list and pass it to this method.

        Args:
            text_1: can be a single prompt or a list of prompts, in which
                case it has to have the same length as the ``text_2`` list
            text_2: The texts to pair with the query to form the input
                to the LLM. See :class:`~vllm.inputs.PromptType` for
                more details about the format of each prompt.
            truncate_prompt_tokens: If set, each tokenized ``(text_1,
                text_2)`` pair is truncated by the tokenizer
                (``truncation=True``) to at most this many tokens.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            prompt_adapter_request: Prompt Adapter request to use for
                generation, if any.

        Returns:
            A list of ``ScoringRequestOutput`` objects containing the
            generated scores in the same order as the input prompts.

        Raises:
            ValueError: If the model is not a pooling cross-encoder started
                with ``--task score``, if a Mistral tokenizer is in use, if a
                prompt carries multi-modal data, or if the input lengths are
                invalid.
            TypeError: If a prompt cannot be reduced to a plain string.
        """
        runner_type = self.llm_engine.model_config.runner_type
        if runner_type != "pooling":
            messages = ["LLM.score() is only supported for pooling models."]

            supported_runner_types = self.llm_engine.model_config \
                .supported_runner_types
            if "pooling" in supported_runner_types:
                messages.append(
                    "Your model supports the 'pooling' runner, but is "
                    f"currently initialized for the '{runner_type}' runner. "
                    "Please initialize vLLM using `--task embed`, "
                    "`--task classify`, `--task score` etc.")

            raise ValueError(" ".join(messages))

        if not self.llm_engine.model_config.is_cross_encoder:
            raise ValueError("Your model does not support cross encoding")
        if self.llm_engine.model_config.task != "score":
            raise ValueError("Score API is only enabled for `--task score`")

        tokenizer = self.llm_engine.get_tokenizer()

        if isinstance(tokenizer, MistralTokenizer):
            raise ValueError(
                "MistralTokenizer not supported for cross-encoding")

        # the tokenizer for models such as
        # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing
        # lists of tokens to the `text` and `text_pair` kwargs
        def ensure_str(prompt: SingletonPrompt) -> str:
            if isinstance(prompt, dict):
                if "multi_modal_data" in prompt:
                    raise ValueError("Multi-modal prompt is not "
                                     "supported for cross encoding")
                elif "prompt_token_ids" in prompt:
                    prompt = tokenizer.decode(
                        cast(TokensPrompt, prompt)["prompt_token_ids"])
                elif "prompt" in prompt:
                    prompt = cast(TextPrompt, prompt)["prompt"]
            # Explicit check rather than ``assert`` so invalid input is still
            # rejected when Python runs with ``-O``.
            if not isinstance(prompt, str):
                raise TypeError(
                    "Cross encoding expects a str, TextPrompt or "
                    f"TokensPrompt, got {type(prompt).__name__}")
            return prompt

        if isinstance(text_1, (str, dict)):
            # Convert a single prompt to a list.
            text_1 = [text_1]
        text_1 = [ensure_str(t) for t in text_1]

        if isinstance(text_2, (str, dict)):
            # Convert a single prompt to a list.
            text_2 = [text_2]
        text_2 = [ensure_str(t) for t in text_2]

        if len(text_1) > 1 and len(text_1) != len(text_2):
            raise ValueError("Input lengths must be either 1:1, 1:N or N:N")
        if len(text_1) == 0:
            raise ValueError("At least one text element must be given")
        if len(text_2) == 0:
            raise ValueError("At least one text_pair element must be given")

        if len(text_1) == 1:
            # 1 -> N: replicate the single query for every text_2 entry.
            text_1 = text_1 * len(text_2)

        input_pairs = list(zip(text_1, text_2))
        pooling_params = PoolingParams()

        tokenization_kwargs: Dict[str, Any] = {}
        if truncate_prompt_tokens is not None:
            tokenization_kwargs["truncation"] = True
            tokenization_kwargs["max_length"] = truncate_prompt_tokens

        parsed_prompts = []

        for q, t in input_pairs:
            prompt_inputs = tokenizer(text=q,
                                      text_pair=t,
                                      **tokenization_kwargs)
            engine_prompt = TokensPrompt(
                prompt_token_ids=prompt_inputs["input_ids"],
                token_type_ids=prompt_inputs.get("token_type_ids"))
            parsed_prompts.append(engine_prompt)

        self._validate_and_add_requests(
            prompts=parsed_prompts,
            params=pooling_params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        items = self.engine_class.validate_outputs(outputs,
                                                   PoolingRequestOutput)

        return [ScoringRequestOutput.from_base(item) for item in items]

    def start_profile(self) -> None:
        """Ask the underlying engine to start profiling."""
        engine = self.llm_engine
        engine.start_profile()

    def stop_profile(self) -> None:
        """Ask the underlying engine to stop profiling."""
        self.llm_engine.stop_profile()

    def sleep(self, level: int = 1):
        """
        Put the engine to sleep. The engine should not process any requests.

        The caller should guarantee that no requests are being processed
        during the sleep period, before `wake_up` is called.

        Args:
            level: The sleep level.

              - Level 1 offloads the model weights and discards the kv cache,
                whose content is forgotten. The model weights are backed up
                in CPU memory, so make sure there is enough CPU memory to
                store them. It is good for sleeping and waking up the engine
                to run the same model again.
              - Level 2 discards both the model weights and the kv cache, and
                the content of both is forgotten. It is good for sleeping and
                waking up the engine to run a different model or update the
                model, where the previous weights are not needed, and it
                reduces CPU memory pressure.
        """
        engine = self.llm_engine
        engine.sleep(level=level)

    def wake_up(self):
        """Wake the engine up after a previous :meth:`sleep` call.

        Delegates to the underlying engine's ``wake_up``.
        """
        self.llm_engine.wake_up()

    # LEGACY
    def _convert_v1_inputs(
1160
1161
        self,
        prompts: Optional[Union[str, List[str]]],
1162
1163
1164
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
    ):
        # skip_tokenizer_init is now checked in engine
1165

1166
1167
1168
1169
1170
1171
        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]
1172

1173
        num_requests = None
1174
1175
        if prompts is not None:
            num_requests = len(prompts)
1176
1177
1178
1179
1180
1181
        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

1182
            num_requests = len(prompt_token_ids)
1183
1184
1185
1186
        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

1187
        parsed_prompts: List[PromptType] = []
1188
        for i in range(num_requests):
1189
            item: PromptType
1190

1191
            if prompts is not None:
1192
1193
1194
                item = TextPrompt(prompt=prompts[i])
            elif prompt_token_ids is not None:
                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
1195
            else:
1196
                raise AssertionError
1197

1198
            parsed_prompts.append(item)
1199

1200
        return parsed_prompts
1201
1202
1203

    def _validate_and_add_requests(
        self,
        prompts: Union[PromptType, Sequence[PromptType]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],
        prompt_adapter_request: Optional[PromptAdapterRequest],
        guided_options: Optional[GuidedDecodingRequest] = None,
        priority: Optional[List[int]] = None,
    ) -> None:
        """Validate per-prompt arguments, then add one engine request per
        prompt.

        ``params`` and ``lora_request`` may each be a single value shared by
        all prompts or a sequence paired one by one with the prompts.
        ``priority``, when non-empty, is also paired one by one. All length
        checks run before any request is added, so a mismatch never leaves a
        partially-submitted batch in the engine.

        Raises:
            ValueError: If a per-prompt sequence does not match the number of
                prompts, or (via ``_add_guided_params``) if guided decoding is
                specified twice.
        """
        if guided_options is not None:
            warnings.warn(
                "guided_options_request is deprecated, use "
                "SamplingParams.guided_decoding instead",
                DeprecationWarning,
                stacklevel=2,
            )

        if isinstance(prompts, (str, dict)):
            # Convert a single prompt to a list.
            prompts = [prompts]

        num_requests = len(prompts)

        # The dispatch loop below indexes any ``Sequence`` per prompt, so the
        # checks use ``Sequence`` too (tuples were previously unchecked).
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")
        if isinstance(lora_request,
                      Sequence) and len(lora_request) != num_requests:
            raise ValueError("The lengths of prompts and lora_request "
                             "must be the same.")
        # An empty/None priority means "default priority 0 for all".
        if priority and len(priority) != num_requests:
            raise ValueError("The lengths of prompts and priority "
                             "must be the same.")

        for sp in params if isinstance(params, Sequence) else (params, ):
            if isinstance(sp, SamplingParams):
                self._add_guided_params(sp, guided_options)

                # We only care about the final output
                sp.output_kind = RequestOutputKind.FINAL_ONLY

        # Add requests to the engine.
        for i, prompt in enumerate(prompts):
            self._add_request(
                prompt,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request[i] if isinstance(
                    lora_request, Sequence) else lora_request,
                prompt_adapter_request=prompt_adapter_request,
                priority=priority[i] if priority else 0,
            )

    def _add_request(
        self,
        prompt: PromptType,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
        priority: int = 0,
    ) -> None:
        """Submit a single request to the engine.

        The request id is the next value drawn from ``self.request_counter``,
        converted to ``str``.
        """
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(
            request_id,
            prompt,
            params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
            priority=priority,
        )

    def _add_guided_params(
            self,
            params: SamplingParams,
            guided_options: Optional[GuidedDecodingRequest] = None):
        """Copy deprecated ``guided_options`` into ``params.guided_decoding``.

        ``params`` is mutated in place and also returned. When
        ``guided_options`` is None, ``params`` is returned unchanged.

        Raises:
            ValueError: If ``params.guided_decoding`` is already set while
                ``guided_options`` is also given.
        """
        if guided_options is None:
            return params

        if params.guided_decoding is not None:
            raise ValueError("Cannot set both guided_options_request and "
                             "params.guided_decoding.")

        params.guided_decoding = GuidedDecodingParams(
            json=guided_options.guided_json,
            regex=guided_options.guided_regex,
            choice=guided_options.guided_choice,
            grammar=guided_options.guided_grammar,
            json_object=guided_options.guided_json_object,
            backend=guided_options.guided_decoding_backend,
            whitespace_pattern=guided_options.guided_whitespace_pattern)
        return params

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, PoolingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: If True, display a tqdm progress bar that advances by
                one for every finished output and, for ``RequestOutput``
                results, shows estimated input/output token throughput.

        Returns:
            All finished outputs, sorted by their integer request ID.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=(f"est. speed input: {0:.2f} toks/s, "
                         f"output: {0:.2f} toks/s"),
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, PoolingRequestOutput]] = []
        total_in_toks = 0
        total_out_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        assert output.prompt_token_ids is not None
                        total_in_toks += len(output.prompt_token_ids)
                        total_out_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        # Read elapsed once so both rates use the same time
                        # base; it can be 0.0 on a coarse clock, in which case
                        # keep the previous postfix instead of dividing by 0.
                        elapsed = pbar.format_dict["elapsed"]
                        if elapsed > 0:
                            in_spd = total_in_toks / elapsed
                            out_spd = total_out_toks / elapsed
                            pbar.postfix = (
                                f"est. speed input: {in_spd:.2f} toks/s, "
                                f"output: {out_spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even if an engine step raises.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))