llm.py 10.4 KB
Newer Older
1
from typing import List, Optional, Union
2

3
import torch
4
from tqdm import tqdm
Zhuohan Li's avatar
Zhuohan Li committed
5
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
8
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
9
from vllm.lora.request import LoRARequest
Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
12
from vllm.sequence import MultiModalData
yhu422's avatar
yhu422 committed
13
from vllm.usage.usage_lib import UsageContext
Woosuk Kwon's avatar
Woosuk Kwon committed
14
from vllm.utils import Counter
15
16
17


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the `AsyncLLMEngine` class instead.

    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq", "squeezellm", and "fp8" (experimental).
            If None, we first check the `quantization_config` attribute in the
            model config file. If that is None, we assume the model weights are
            not quantized and use `dtype` to determine the data type of
            the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: See ParallelConfig
        **kwargs: Any further keyword arguments, forwarded unchanged to
            `EngineArgs`. `disable_log_stats` defaults to True here unless
            the caller passes it explicitly.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_context_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        # Offline use: turn stats logging off by default, but let an explicit
        # caller-supplied value win.
        kwargs.setdefault("disable_log_stats", True)
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(
            engine_args, usage_context=UsageContext.LLM_CLASS)
        # Source of request ids; `_run_engine` parses them back with int().
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        """Return the `tokenizer` attribute of the engine's tokenizer object."""
        return self.llm_engine.tokenizer.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        """Replace the `tokenizer` attribute of the engine's tokenizer object.

        Args:
            tokenizer: The HuggingFace tokenizer to install.
        """
        self.llm_engine.tokenizer.tokenizer = tokenizer

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[SamplingParams] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        NOTE: This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: A prompt or a list of prompts to generate completions for.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters. The same
                parameters are applied to every prompt.
            prompt_token_ids: A list of token IDs for the prompts. If None, we
                use the tokenizer to convert the prompts to token IDs.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.
            multi_modal_data: Multi modal data. Entry `i` along the first
                dimension of `multi_modal_data.data` is attached to request
                `i`. The caller's object is not modified.

        Returns:
            A list of `RequestOutput` objects containing the generated
            completions in the same order as the input prompts. Requests that
            were already pending in the engine are drained and returned too.

        Raises:
            ValueError: If neither `prompts` nor `prompt_token_ids` is given,
                if both are given with different lengths, or if
                `multi_modal_data` has fewer entries than there are requests.
        """
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        if prompts is not None:
            num_requests = len(prompts)
        else:
            assert prompt_token_ids is not None
            num_requests = len(prompt_token_ids)

        mm_batch = None
        if multi_modal_data:
            # Cast into a local tensor instead of rebinding
            # `multi_modal_data.data`, so the caller's object is untouched.
            # NOTE(review): float16 is hard-coded regardless of the `dtype`
            # engine argument -- confirm this is what the model expects.
            mm_batch = multi_modal_data.data.to(torch.float16)
            # Validate before enqueuing anything: an IndexError midway through
            # the loop below would leave the earlier requests queued in the
            # engine, and the next generate() call would return them as well.
            # Extra trailing entries are ignored, as before.
            if mm_batch.shape[0] < num_requests:
                raise ValueError(
                    f"multi_modal_data has {mm_batch.shape[0]} entries along "
                    f"its first dimension, but {num_requests} requests were "
                    "given.")

        # Add requests to the engine.
        for i in range(num_requests):
            prompt = prompts[i] if prompts is not None else None
            token_ids = (prompt_token_ids[i]
                         if prompt_token_ids is not None else None)
            self._add_request(
                prompt,
                sampling_params,
                token_ids,
                lora_request=lora_request,
                # Get ith image while maintaining the batch dim.
                multi_modal_data=MultiModalData(
                    type=multi_modal_data.type,
                    data=mm_batch[i].unsqueeze(0))
                if mm_batch is not None else None,
            )
        return self._run_engine(use_tqdm)

    def _add_request(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
        lora_request: Optional[LoRARequest] = None,
        multi_modal_data: Optional[MultiModalData] = None,
    ) -> None:
        """Enqueue one request on the engine under the next counter id."""
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id,
                                    prompt,
                                    sampling_params,
                                    prompt_token_ids,
                                    lora_request=lora_request,
                                    multi_modal_data=multi_modal_data)

    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
        """Step the engine until no unfinished requests remain.

        This drains every unfinished request in the engine, not only those
        added by the current `generate` call.

        Args:
            use_tqdm: Whether to show a progress bar over finished requests.

        Returns:
            The finished `RequestOutput`s, sorted by integer request id.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests,
                        desc="Processed prompts",
                        dynamic_ncols=True)
        # Run the engine.
        outputs: List[RequestOutput] = []
        try:
            while self.llm_engine.has_unfinished_requests():
                step_outputs = self.llm_engine.step()
                for output in step_outputs:
                    if output.finished:
                        outputs.append(output)
                        if pbar is not None:
                            pbar.update(1)
        finally:
            # Close the bar even if a step raises, so it is not leaked.
            if pbar is not None:
                pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may be finished earlier than
        # its previous requests. Ids come from `str(next(self.request_counter))`
        # in `_add_request`, so int() parsing holds for them (presumably the
        # counter is increasing, making this submission order -- confirm).
        return sorted(outputs, key=lambda x: int(x.request_id))