llm.py 8.71 KB
Newer Older
1
from typing import List, Optional, Union
2
3

from tqdm import tqdm
Zhuohan Li's avatar
Zhuohan Li committed
4
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
7
8
9
10
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils import Counter
11
12
13


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (a.k.a. the KV cache). Given a batch of prompts and sampling
    parameters, this class generates texts from the model, using an
    intelligent batching mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the `AsyncLLMEngine` class instead.

    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq" and "squeezellm". If None, we assume the
            model weights are not quantized and use `dtype` to determine the
            data type of the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context length covered by CUDA
            graphs. When a sequence has a context length larger than this, we
            fall back to eager mode.
        **kwargs: Additional keyword arguments forwarded unchanged to
            `EngineArgs`. `disable_log_stats` defaults to True here unless the
            caller passes it explicitly.
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_context_len_to_capture: int = 8192,
        **kwargs,
    ) -> None:
        # Offline use: silence engine stats logging unless explicitly requested.
        kwargs.setdefault("disable_log_stats", True)
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(engine_args)
        # Source of request IDs (see `_add_request`). `_run_engine` sorts
        # finished outputs by the integer value of these IDs to restore
        # submission order, so this assumes the counter yields increasing
        # integers.
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        """Return the tokenizer held by the underlying engine."""
        return self.llm_engine.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        """Replace the tokenizer held by the underlying engine."""
        self.llm_engine.tokenizer = tokenizer

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[SamplingParams] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        use_tqdm: bool = True,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        NOTE: This method submits all given prompts to the engine at once and
        lets it batch them, considering the memory constraint. For the best
        performance, put all of your prompts into a single list and pass it to
        this method.

        Args:
            prompts: A single prompt string, or a list of prompts to generate
                completions for.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters. The same object
                is used for every prompt.
            prompt_token_ids: A list of token IDs for the prompts. If None, we
                use the tokenizer to convert the prompts to token IDs. If both
                this and `prompts` are given, they must have the same length.
            use_tqdm: Whether to use tqdm to display the progress bar.

        Returns:
            A list of `RequestOutput` objects containing the generated
            completions in the same order as the input prompts.

        Raises:
            ValueError: If neither `prompts` nor `prompt_token_ids` is given,
                or if both are given with different lengths.
        """
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        # Add requests to the engine. Whichever input was omitted is padded
        # with None so prompts and token IDs can be walked in lockstep.
        num_requests = (len(prompts) if prompts is not None else
                        len(prompt_token_ids))
        prompt_list = (prompts
                       if prompts is not None else [None] * num_requests)
        token_ids_list = (prompt_token_ids if prompt_token_ids is not None
                          else [None] * num_requests)
        for prompt, token_ids in zip(prompt_list, token_ids_list):
            self._add_request(prompt, sampling_params, token_ids)
        return self._run_engine(use_tqdm)

    def _add_request(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
    ) -> None:
        """Submit one request to the engine under a fresh string request ID."""
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id, prompt, sampling_params,
                                    prompt_token_ids)

    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
        """Step the engine until it has no unfinished requests.

        Collects every output reported as finished and returns them sorted by
        the integer value of their request ID.
        """
        # Initialize tqdm.
        pbar = None
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests, desc="Processed prompts")
        # Run the engine.
        outputs: List[RequestOutput] = []
        try:
            while self.llm_engine.has_unfinished_requests():
                for output in self.llm_engine.step():
                    if output.finished:
                        outputs.append(output)
                        if pbar is not None:
                            pbar.update(1)
        finally:
            # Close the bar even if a step raises, so the terminal is not
            # left with a dangling progress bar.
            if pbar is not None:
                pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were submitted before them.
        outputs.sort(key=lambda x: int(x.request_id))
        return outputs