from typing import Optional

import torch
from transformers import PretrainedConfig

from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory

logger = init_logger(__name__)

_GB = 1 << 30  # 1 GiB in bytes.


class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights; defaults to
            the default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
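
    Example (illustrative sketch; the model name below is just a placeholder):

        config = ModelConfig(
            model="facebook/opt-125m",
            tokenizer="facebook/opt-125m",
            tokenizer_mode="auto",
            trust_remote_code=False,
            download_dir=None,
            load_format="auto",
            dtype="auto",
            seed=0,
        )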
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        dtype: str,
        seed: int,
        max_model_len: Optional[int] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed

        self.hf_config = get_config(model, trust_remote_code)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self._verify_load_format()
        self._verify_tokenizer_mode()
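
        # Leave max_model_len unset first so that get_max_model_len() derives
        # the limit from the HuggingFace config; the user-provided value is
        # applied afterwards.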
        self.max_model_len = None
        if max_model_len is not None:
            derived_max_model_len = self.get_max_model_len()
            if max_model_len > derived_max_model_len:
                logger.warning(
                    f"User-specified max_model_len ({max_model_len}) is "
                    f"greater than the derived max_model_len "
                    f"({derived_max_model_len}). Make sure the value is "
                    "correct and within the model context size.")
        self.max_model_len = max_model_len

    def _verify_load_format(self) -> None:
        load_format = self.load_format.lower()
        if load_format not in [
                "auto", "pt", "safetensors", "npcache", "dummy"
        ]:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_hidden_size(self) -> int:
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        # FIXME(woosuk): This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
        # For GPTBigCode & Falcon:
        # Note: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            return 1
        # For Falcon:
        if getattr(self.hf_config, "n_head_kv", None) is not None:
            return (self.hf_config.n_head_kv //
                    parallel_config.tensor_parallel_size)
        # For LLaMA-2:
        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
            return (self.hf_config.num_key_value_heads //
                    parallel_config.tensor_parallel_size)
        total_num_attention_heads = self.hf_config.num_attention_heads
        return total_num_attention_heads // parallel_config.tensor_parallel_size

    def get_max_model_len(self) -> int:
        if self.max_model_len is not None:
            return self.max_model_len
        max_model_len = float("inf")
        possible_keys = [
            # OPT
            "max_position_embeddings",
            # GPT-2
            "n_positions",
            # MPT
            "max_seq_len",
            # Others
            "max_sequence_length",
            "max_seq_length",
            "seq_len",
        ]
        for key in possible_keys:
            max_len_key = getattr(self.hf_config, key, None)
            if max_len_key is not None:
                max_model_len = min(max_model_len, max_len_key)
        return max_model_len

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution.
        swap_space: Size of the CPU swap space per GPU (in GiB).
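
    Example (illustrative values):

        cache_config = CacheConfig(
            block_size=16,
            gpu_memory_utilization=0.9,
            swap_space=4,
        )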
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self._verify_args()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def _verify_args(self) -> None:
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must not be greater than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node
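        # For example, swap_space=4 GiB with tensor_parallel_size=4 reserves
        # 16 GiB of CPU memory, which must stay below 70% of the total.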

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)


class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Will be set to
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.
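
    Example (illustrative; a single-GPU setup that does not need Ray):

        parallel_config = ParallelConfig(
            pipeline_parallel_size=1,
            tensor_parallel_size=1,
            worker_use_ray=False,
        )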
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray

        self.world_size = pipeline_parallel_size * tensor_parallel_size
        if self.world_size > 1:
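            # Ray is needed to coordinate multiple workers.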
            self.worker_use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")


class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
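
    Example (illustrative values):

        scheduler_config = SchedulerConfig(
            max_num_batched_tokens=2560,
            max_num_seqs=256,
            max_model_len=2048,
        )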
    """

    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
                 max_model_len: int) -> None:
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}


def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: str,
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    dtype = dtype.lower()
    if dtype == "auto":
        if config_dtype == torch.float32:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
        else:
            torch_dtype = config_dtype
    else:
        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unknown dtype: {dtype}")
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning(f"Casting {config_dtype} to {torch_dtype}.")

    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}.")
    return torch_dtype
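

# Example usage of _get_and_verify_dtype (an illustrative sketch; the model
# name is just a placeholder):
#
#     from transformers import AutoConfig
#     hf_config = AutoConfig.from_pretrained("facebook/opt-125m")
#     # For an FP16 or FP32 checkpoint, "auto" resolves to torch.float16.
#     torch_dtype = _get_and_verify_dtype(hf_config, "auto")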