config.py 10.5 KB
Newer Older
1
2
3
from typing import Optional

import torch
4
from transformers import PretrainedConfig
5

Woosuk Kwon's avatar
Woosuk Kwon committed
6
from vllm.logger import init_logger
7
from vllm.transformers_utils.config import get_config
Woosuk Kwon's avatar
Woosuk Kwon committed
8
from vllm.utils import get_cpu_memory
9
10
11

# Module-level logger used for the warnings emitted further down this file.
logger = init_logger(__name__)

# Number of bytes in one GiB (2**30); used to convert sizes given in GiB.
_GB = 1 << 30
13

14
15

class ModelConfig:
    """Configuration for the model.

    Besides storing the constructor arguments, the instance loads the
    HuggingFace config via ``get_config(model, trust_remote_code)`` into
    ``self.hf_config`` and resolves ``dtype`` into ``self.dtype`` through
    ``_get_and_verify_dtype``.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        use_np_weights: Save a numpy copy of model weights for faster loading.
            This can increase the disk usage by up to 2x.
        use_dummy_weights: Use dummy values for model weights (for profiling).
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.

    Raises:
        ValueError: If ``tokenizer_mode`` is neither "auto" nor "slow"
            (case-insensitive).
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        use_np_weights: bool,
        use_dummy_weights: bool,
        dtype: str,
        seed: int,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.use_np_weights = use_np_weights
        self.use_dummy_weights = use_dummy_weights
        self.seed = seed

        # The HuggingFace config is the source for the getters below and for
        # resolving the requested dtype string into a concrete dtype.
        self.hf_config = get_config(model, trust_remote_code)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self._verify_tokenizer_mode()

    def _verify_tokenizer_mode(self) -> None:
        """Validate ``self.tokenizer_mode`` and store it lower-cased.

        Raises:
            ValueError: If the mode is not "auto" or "slow". The message
                reports the value as originally given.
        """
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ("auto", "slow"):
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the model can be split evenly across the parallel setup.

        Args:
            parallel_config: Provides ``tensor_parallel_size`` and
                ``pipeline_parallel_size``.

        Raises:
            ValueError: If the number of attention heads is not divisible by
                the tensor parallel size, or the number of hidden layers is
                not divisible by the pipeline parallel size.
        """
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_hidden_size(self) -> int:
        """Return ``hidden_size`` from the HuggingFace config."""
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        """Return ``hidden_size // num_attention_heads``."""
        # FIXME(woosuk): This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of heads per tensor-parallel worker.

        Model-specific key/value head counts take precedence over the plain
        attention head count, checked in this order: ``multi_query``,
        ``n_head_kv``, ``num_key_value_heads``, ``num_attention_heads``.

        NOTE(review): the divisions floor, so a KV head count smaller than
        ``tensor_parallel_size`` yields 0 -- confirm callers never hit this.

        Args:
            parallel_config: Provides ``tensor_parallel_size``.
        """
        tensor_parallel_size = parallel_config.tensor_parallel_size

        # For GPTBigCode:
        if getattr(self.hf_config, "multi_query", False):
            # Multi-query attention, only one KV head.
            return 1
        # For Falcon:
        if getattr(self.hf_config, "n_head_kv", None) is not None:
            return self.hf_config.n_head_kv // tensor_parallel_size
        # For LLaMA-2:
        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
            return self.hf_config.num_key_value_heads // tensor_parallel_size
        total_num_attention_heads = self.hf_config.num_attention_heads
        return total_num_attention_heads // tensor_parallel_size

    def get_max_model_len(self) -> int:
        """Return the smallest length limit found in the HuggingFace config.

        Every known attribute name is consulted and the minimum of the
        values that are present (and not ``None``) is returned.

        NOTE: if the config defines none of these attributes the result is
        ``float("inf")``, which does not match the ``int`` annotation.
        """
        possible_keys = [
            # OPT
            "max_position_embeddings",
            # GPT-2
            "n_positions",
            # MPT
            "max_seq_len",
            # Others
            "max_sequence_length",
            "max_seq_length",
            "seq_len",
        ]
        candidates = (getattr(self.hf_config, key, None)
                      for key in possible_keys)
        return min((value for value in candidates if value is not None),
                   default=float("inf"))

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of hidden layers per pipeline-parallel stage.

        Args:
            parallel_config: Provides ``pipeline_parallel_size``.
        """
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:
    """Configuration for the KV cache.

    ``swap_space`` is stored converted to bytes as ``swap_space_bytes``.
    ``num_gpu_blocks`` and ``num_cpu_blocks`` start as ``None`` and are
    expected to be filled in after profiling.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution.
        swap_space: Size of the CPU swap space per GPU (in GiB).

    Raises:
        ValueError: If ``gpu_memory_utilization`` is greater than 1.0.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self._verify_args()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def _verify_args(self) -> None:
        """Reject a GPU memory utilization above 1.0.

        NOTE: the comparison is strict, so exactly 1.0 is accepted even
        though the error message reads "less than 1.0".

        Raises:
            ValueError: If ``self.gpu_memory_utilization > 1.0``.
        """
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check the requested swap space against the total CPU memory.

        The swap space per GPU is multiplied by the tensor parallel size and
        compared with the value returned by ``get_cpu_memory()``.

        Args:
            parallel_config: Provides ``tensor_parallel_size``.

        Raises:
            ValueError: If the swap space exceeds 70% of the total CPU
                memory. Above 40% only a warning is logged.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)
185

186
187

class ParallelConfig:
    """Configuration for the distributed execution.

    ``world_size`` is derived as
    ``pipeline_parallel_size * tensor_parallel_size``.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Will be set to
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.

    Raises:
        NotImplementedError: If ``pipeline_parallel_size`` is greater than 1.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray

        self.world_size = pipeline_parallel_size * tensor_parallel_size
        # More than one worker overrides the caller's choice and forces Ray.
        if self.world_size > 1:
            self.worker_use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        """Reject configurations that request pipeline parallelism.

        Raises:
            NotImplementedError: If ``self.pipeline_parallel_size > 1``.
        """
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")


class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
    """

    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
                 max_model_len: int) -> None:
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257


# Maps each accepted dtype string to its torch dtype. The aliases are listed
# grouped by the torch dtype they resolve to.
_STR_DTYPE_TO_TORCH_DTYPE = {
    alias: torch_dtype
    for torch_dtype, aliases in (
        (torch.float16, ("half", "float16")),
        (torch.float32, ("float", "float32")),
        (torch.bfloat16, ("bfloat16", )),
    )
    for alias in aliases
}


def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: str,
) -> torch.dtype:
    """Resolve a dtype string into a torch dtype and check GPU support.

    Args:
        config: HuggingFace model config; its ``torch_dtype`` attribute is
            read and treated as ``torch.float32`` when missing or ``None``.
        dtype: Requested dtype (case-insensitive). "auto" picks float16 for
            float32 models and the config's own dtype otherwise; any other
            value must be a key of ``_STR_DTYPE_TO_TORCH_DTYPE``.

    Returns:
        The torch dtype to use.

    Raises:
        ValueError: If ``dtype`` is not a known string, or if bfloat16 is
            selected and the GPU's compute capability is below 8.0.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    dtype = dtype.lower()
    if dtype == "auto":
        if config_dtype == torch.float32:
            # Following the common practice, we use float16 for float32 models.
            torch_dtype = torch.float16
        else:
            torch_dtype = config_dtype
    else:
        if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unknown dtype: {dtype}")
        torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]

    # Verify the dtype. Upcasting to float32 and downcasting from float32
    # are both allowed silently; only a cast where neither side is float32
    # (i.e. between float16 and bfloat16) is reported with a warning.
    if (torch_dtype != config_dtype and torch_dtype != torch.float32
            and config_dtype != torch.float32):
        logger.warning(f"Casting {config_dtype} to {torch_dtype}.")

    # Check if the GPU supports the dtype.
    if torch_dtype == torch.bfloat16:
        compute_capability = torch.cuda.get_device_capability()
        if compute_capability[0] < 8:
            gpu_name = torch.cuda.get_device_name()
            raise ValueError(
                "Bfloat16 is only supported on GPUs with compute capability "
                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
                f"{compute_capability[0]}.{compute_capability[1]}.")
    return torch_dtype