llm.py 23.1 KB
Newer Older
1
2
from contextlib import contextmanager
from typing import ClassVar, List, Optional, Sequence, Union, cast, overload
3
4

from tqdm import tqdm
Zhuohan Li's avatar
Zhuohan Li committed
5
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
9
10
11
from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
                         TextTokensPrompt, TokensPrompt,
                         parse_and_batch_prompt)
12
from vllm.logger import init_logger
13
from vllm.lora.request import LoRARequest
14
15
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
Woosuk Kwon's avatar
Woosuk Kwon committed
16
from vllm.sampling_params import SamplingParams
17
from vllm.sequence import MultiModalData
yhu422's avatar
yhu422 committed
18
from vllm.usage.usage_lib import UsageContext
19
from vllm.utils import Counter, deprecate_kwargs
20

21
22
# Module-level logger created with vllm's ``init_logger`` helper, named after
# this module. (It is not referenced by the code visible in this file.)
logger = init_logger(__name__)

23
24

class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        skip_tokenizer_init: If true, skip initialization of tokenizer and
            detokenizer. Expect valid prompt_token_ids and None for prompt
            from the input.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead).
        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: See ParallelConfig.
        **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See
            :ref:`engine_args`). ``disable_log_stats`` defaults to True here
            unless it is passed explicitly.

    Note:
        This class is intended to be used for offline inference. For online
        serving, use the :class:`~vllm.AsyncLLMEngine` class instead.
    """

    DEPRECATE_LEGACY: ClassVar[bool] = False
    """A flag to toggle whether to deprecate the legacy generate/encode API."""

    @classmethod
    @contextmanager
    def deprecate_legacy_api(cls):
        """Temporarily treat the legacy ``generate``/``encode`` kwargs as
        deprecated.

        While the ``with`` block is active, ``DEPRECATE_LEGACY`` is True. On
        exit, the previous value is restored. This also happens when the body
        raises, so a failing block cannot leave the flag stuck on, and nested
        uses do not switch it off early.
        """
        previous = cls.DEPRECATE_LEGACY
        cls.DEPRECATE_LEGACY = True
        try:
            yield
        finally:
            cls.DEPRECATE_LEGACY = previous

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        skip_tokenizer_init: bool = False,
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_seq_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        # Default `disable_log_stats` to True unless the caller set it.
        kwargs.setdefault("disable_log_stats", True)
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            skip_tokenizer_init=skip_tokenizer_init,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            max_seq_len_to_capture=max_seq_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        # Source of request IDs; `_run_engine` sorts outputs by int(request_id).
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        """Return the HuggingFace tokenizer held by the engine's tokenizer."""
        return self.llm_engine.tokenizer.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        """Replace the HuggingFace tokenizer held by the engine's tokenizer."""
        self.llm_engine.tokenizer.tokenizer = tokenizer

    @overload  # LEGACY: single (prompt + optional token ids)
    def generate(
        self,
        prompts: str,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def generate(
        self,
        prompts: List[str],
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[str] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def generate(
        self,
        prompts: Optional[List[str]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        List[SamplingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def generate(
        self,
        prompts: None,
        sampling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        ...

    @overload
    def generate(
        self,
        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[RequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      "multi_modal_data",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def generate(
        self,
        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                       Optional[Union[str, List[str]]]] = None,
        sampling_params: Optional[Union[SamplingParams,
                                        Sequence[SamplingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            inputs: The inputs to the LLM. You may pass a sequence of inputs for
                batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
                for more details about the format of each input. (Passed
                positionally; the implementation parameter is named
                ``prompts`` for backward compatibility.)
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
                When it is a single value, it is applied to every prompt.
                When it is a sequence, it must have the same length as the
                prompts and it is paired one by one with the prompt.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `RequestOutput` objects containing the
            generated completions in the same order as the input prompts.

        Raises:
            ValueError: If a sequence of sampling parameters does not match
                the number of prompts, or if the legacy ``prompts`` and
                ``prompt_token_ids`` arguments are inconsistent.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if prompt_token_ids is not None or multi_modal_data is not None:
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
                multi_modal_data=multi_modal_data,
            )
        else:
            inputs = cast(
                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                prompts)

        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=sampling_params,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, RequestOutput)

    @overload  # LEGACY: single (prompt + optional token ids)
    def encode(
        self,
        prompts: str,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[int]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (prompt + optional token ids)
    def encode(
        self,
        prompts: List[str],
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[str] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[int],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: multi (token ids + optional prompt)
    def encode(
        self,
        prompts: Optional[List[str]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        *,
        prompt_token_ids: List[List[int]],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload  # LEGACY: single or multi token ids [pos-only]
    def encode(
        self,
        prompts: None,
        pooling_params: None,
        prompt_token_ids: Union[List[int], List[List[int]]],
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @overload
    def encode(
        self,
        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
        /,  # We may enable `inputs` keyword after removing the old API
        *,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[EmbeddingRequestOutput]:
        ...

    @deprecate_kwargs("prompts",
                      "prompt_token_ids",
                      "multi_modal_data",
                      is_deprecated=lambda: LLM.DEPRECATE_LEGACY,
                      additional_message="Please use the 'inputs' parameter "
                      "instead.")
    def encode(
        self,
        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                       Optional[Union[str, List[str]]]] = None,
        pooling_params: Optional[Union[PoolingParams,
                                       Sequence[PoolingParams]]] = None,
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[EmbeddingRequestOutput]:
        """Generates the embeddings for the input prompts.

        This method automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            inputs: The inputs to the LLM. You may pass a sequence of inputs for
                batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
                for more details about the format of each input. (Passed
                positionally; the implementation parameter is named
                ``prompts`` for backward compatibility.)
            pooling_params: The pooling parameters for pooling. If None, we
                use the default pooling parameters. When it is a sequence, it
                must have the same length as the inputs and is paired one by
                one with them.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `EmbeddingRequestOutput` objects containing the
            generated embeddings in the same order as the input prompts.

        Raises:
            ValueError: If a sequence of pooling parameters does not match
                the number of prompts, or if the legacy ``prompts`` and
                ``prompt_token_ids`` arguments are inconsistent.

        Note:
            Using ``prompts`` and ``prompt_token_ids`` as keyword parameters is
            considered legacy and may be deprecated in the future. You should
            instead pass them via the ``inputs`` parameter.
        """
        if prompt_token_ids is not None or multi_modal_data is not None:
            inputs = self._convert_v1_inputs(
                prompts=cast(Optional[Union[str, List[str]]], prompts),
                prompt_token_ids=prompt_token_ids,
                multi_modal_data=multi_modal_data,
            )
        else:
            inputs = cast(
                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                prompts)

        if pooling_params is None:
            # Use default pooling params.
            pooling_params = PoolingParams()

        self._validate_and_add_requests(
            inputs=inputs,
            params=pooling_params,
            lora_request=lora_request,
        )

        outputs = self._run_engine(use_tqdm=use_tqdm)
        return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput)

    # LEGACY
    def _convert_v1_inputs(
        self,
        prompts: Optional[Union[str, List[str]]],
        prompt_token_ids: Optional[Union[List[int], List[List[int]]]],
        multi_modal_data: Optional[MultiModalData],
    ) -> List[PromptInputs]:
        """Convert the legacy ``prompts``/``prompt_token_ids``/
        ``multi_modal_data`` arguments into a list of prompt inputs.

        Both ``prompts`` and ``prompt_token_ids`` are first normalized to a
        batch with ``parse_and_batch_prompt``. They are then zipped
        element-wise into ``TextTokensPrompt``, ``TextPrompt`` or
        ``TokensPrompt`` items. When ``multi_modal_data`` is given, the same
        object is attached to every item.

        Raises:
            ValueError: If both batches are given with different lengths, or
                if neither is given.
        """
        # skip_tokenizer_init is now checked in engine

        if prompts is not None:
            prompts = [p["content"] for p in parse_and_batch_prompt(prompts)]
        if prompt_token_ids is not None:
            prompt_token_ids = [
                p["content"] for p in parse_and_batch_prompt(prompt_token_ids)
            ]

        num_requests = None
        if prompts is not None:
            num_requests = len(prompts)

        if prompt_token_ids is not None:
            if (num_requests is not None
                    and num_requests != len(prompt_token_ids)):
                raise ValueError("The lengths of prompts and prompt_token_ids "
                                 "must be the same.")

            num_requests = len(prompt_token_ids)

        if num_requests is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")

        inputs: List[PromptInputs] = []
        for i in range(num_requests):
            if prompts is not None:
                if prompt_token_ids is not None:
                    item = TextTokensPrompt(
                        prompt=prompts[i],
                        prompt_token_ids=prompt_token_ids[i])
                else:
                    item = TextPrompt(prompt=prompts[i])
            else:
                if prompt_token_ids is not None:
                    item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
                else:
                    # Unreachable: num_requests is None in this case and the
                    # check above has already raised.
                    raise AssertionError

            if multi_modal_data is not None:
                item["multi_modal_data"] = multi_modal_data

            inputs.append(item)

        return inputs

    def _validate_and_add_requests(
        self,
        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
        params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                      Sequence[PoolingParams]],
        lora_request: Optional[LoRARequest],
    ) -> None:
        """Validate the inputs/params pairing and enqueue one engine request
        per input.

        A single ``str`` or ``dict`` input is treated as a batch of one.
        ``params`` is either one object shared by every request or a
        sequence paired element-wise with ``inputs``.

        Raises:
            ValueError: If ``params`` is a sequence whose length differs from
                the number of inputs.
        """
        if isinstance(inputs, (str, dict)):
            # Convert a single prompt to a list.
            inputs = [inputs]

        num_requests = len(inputs)

        # Use the same `Sequence` check here as for indexing below. Otherwise
        # a mis-sized tuple of params would skip validation and later raise
        # IndexError (too short) or silently drop params (too long).
        if isinstance(params, Sequence) and len(params) != num_requests:
            raise ValueError("The lengths of prompts and params "
                             "must be the same.")

        # Add requests to the engine.
        for i, request_inputs in enumerate(inputs):
            self._add_request(
                request_inputs,
                params[i] if isinstance(params, Sequence) else params,
                lora_request=lora_request,
            )

    def _add_request(
        self,
        inputs: PromptInputs,
        params: Union[SamplingParams, PoolingParams],
        lora_request: Optional[LoRARequest] = None,
    ) -> None:
        """Submit one request to the engine under a new ID taken from
        ``self.request_counter``."""
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id,
                                    inputs,
                                    params,
                                    lora_request=lora_request)

    def _run_engine(
            self, *, use_tqdm: bool
    ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
        """Step the engine until no unfinished requests remain.

        Args:
            use_tqdm: Whether to display a progress bar. For `RequestOutput`
                results, the bar's postfix also shows the running generation
                speed in tokens per second.

        Returns:
            All finished outputs, sorted by their integer request ID.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(
                total=num_requests,
                desc="Processed prompts",
                dynamic_ncols=True,
                postfix=f"Generation Speed: {0:.2f} toks/s",
            )

        # Run the engine.
        outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
        total_toks = 0
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if not output.finished:
                        continue
                    outputs.append(output)
                    if pbar is None:
                        continue
                    if isinstance(output, RequestOutput):
                        # Calculate tokens only for RequestOutput
                        total_toks += sum(
                            len(stp.token_ids) for stp in output.outputs)
                        elapsed = pbar.format_dict["elapsed"]
                        # On coarse timers elapsed can still be 0; skip the
                        # speed update rather than raise ZeroDivisionError.
                        if elapsed > 0:
                            spd = total_toks / elapsed
                            pbar.postfix = (
                                f"Generation Speed: {spd:.2f} toks/s")
                    pbar.update(1)
        finally:
            # Close the bar even if an engine step raises.
            if pbar is not None:
                pbar.close()

        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests.
        return sorted(outputs, key=lambda x: int(x.request_id))