from typing import List, Optional, Union

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.lora.request import LoRARequest
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils import Counter


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the `AsyncLLMEngine` class instead.
    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq" and "squeezellm". If None, we first check
            the `quantization_config` attribute in the model config file. If
            that is None, we assume the model weights are not quantized and use
            `dtype` to determine the data type of the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid mode.
        max_context_len_to_capture: The maximum context length covered by CUDA
            graphs. When a sequence has a context length larger than this, we
            fall back to eager mode.
        disable_custom_all_reduce: See ParallelConfig
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_context_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(engine_args)
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        return self.llm_engine.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        self.llm_engine.tokenizer = tokenizer

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[SamplingParams] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        prefix_pos: Optional[Union[int, List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        NOTE: This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: A list of prompts to generate completions for.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
            prompt_token_ids: A list of token IDs for the prompts. If None, we
                use the tokenizer to convert the prompts to token IDs.
            prefix_pos: If not None, we use the given position as the prefix
                position for each prompt. We will cache the prefix's KV
                cache and reuse it for the next request with the same prefix.
                This is an experimental feature, and may be replaced with
                automatic prefix caching in the future.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `RequestOutput` objects containing the generated
            completions in the same order as the input prompts.
        """
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        # Add requests to the engine.
        num_requests = len(prompts) if prompts is not None else len(
            prompt_token_ids)
        for i in range(num_requests):
            prompt = prompts[i] if prompts is not None else None
            prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
            token_ids = None if prompt_token_ids is None else prompt_token_ids[
                i]
            self._add_request(prompt,
                              sampling_params,
                              token_ids,
                              lora_request=lora_request,
                              prefix_pos=prefix_pos_i)
        return self._run_engine(use_tqdm)

    def _add_request(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
        lora_request: Optional[LoRARequest] = None,
        prefix_pos: Optional[int] = None,
    ) -> None:
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id,
                                    prompt,
                                    sampling_params,
                                    prompt_token_ids,
                                    lora_request=lora_request,
                                    prefix_pos=prefix_pos)

    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests, desc="Processed prompts")
        # Run the engine.
        outputs: List[RequestOutput] = []
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were submitted before them.
        outputs = sorted(outputs, key=lambda x: int(x.request_id))
        return outputs