config.py 20.3 KB
Newer Older
1
from typing import Optional, Union
2
import os
3
4

import torch
5
from transformers import PretrainedConfig
6

Woosuk Kwon's avatar
Woosuk Kwon committed
7
from vllm.logger import init_logger
8
from vllm.transformers_utils.config import get_config
9
from vllm.utils import get_cpu_memory, is_hip
10
11
12

logger = init_logger(__name__)

# Bytes per GiB (2**30). Used to convert GiB-denominated settings, such as
# the per-GPU swap space, into bytes.
_GB = 1024**3
14

15
16

class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights, defaults to
            the default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
            On ROCm, "safetensors" is rejected and "auto" resolves to "pt";
            Mixtral models are always forced to "pt".
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization

        if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
            # Download the model from the ModelScope hub. The import is lazy so
            # that modelscope is only required when this option is enabled.
            from modelscope.hub.snapshot_download import snapshot_download  # pylint: disable=C
            model_path = snapshot_download(model_id=model,
                                           cache_dir=download_dir,
                                           revision=revision)
            # The local snapshot directory replaces the model, tokenizer and
            # download directory settings.
            self.model = model_path
            self.download_dir = model_path
            self.tokenizer = model_path

        self.hf_config = get_config(self.model, trust_remote_code, revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                     max_model_len)
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()

    def _verify_load_format(self) -> None:
        """Validate ``self.load_format`` and store its resolved, lower-cased form.

        Raises:
            ValueError: If the format is unknown, or unsupported on ROCm.
        """
        load_format = self.load_format.lower()
        supported_load_format = [
            "auto", "pt", "safetensors", "npcache", "dummy"
        ]
        rocm_not_supported_load_format = ["safetensors"]
        if load_format not in supported_load_format:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        if is_hip():
            # Check against the list declared above (rather than a duplicated
            # literal) so the check and the error message cannot drift apart.
            if load_format in rocm_not_supported_load_format:
                rocm_supported_load_format = [
                    f for f in supported_load_format
                    if f not in rocm_not_supported_load_format
                ]
                raise ValueError(
                    f"load format \'{load_format}\' is not supported in ROCm. "
                    f"Supported load format are "
                    f"{rocm_supported_load_format}")
            # Force ROCm to load from pt weights if nothing specific is set.
            if load_format == "auto":
                load_format = "pt"

        # FIXME(woosuk): This is a temporary hack. Support safetensor weights.
        # HF configs may define `architectures` as None, so fall back to an
        # empty list before the membership test (`x in None` is a TypeError).
        architectures = getattr(self.hf_config, "architectures", None) or []
        if "MixtralForCausalLM" in architectures and load_format != "pt":
            logger.info(
                "Currently, only 'pt' format is supported for Mixtral. "
                "Changing the format to 'pt'. This may re-download the "
                "weights if you have downloaded the safetensor weights.")
            load_format = "pt"

        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:
        """Lower-case ``self.tokenizer_mode``; raise ValueError unless it is
        "auto" or "slow"."""
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        """Reconcile the ``quantization`` argument with the HF config.

        The method named in ``hf_config.quantization_config["quant_method"]``
        is used when no argument was given; a conflicting argument raises.

        Raises:
            ValueError: On a mismatch with the model config, an unknown
                method, or a method unsupported on ROCm.
        """
        supported_quantization = ["awq", "squeezellm"]
        rocm_not_supported_quantization = ["awq"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
        if hf_quant_config is not None:
            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
            if self.quantization is None:
                self.quantization = hf_quant_method
            elif self.quantization != hf_quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({hf_quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if (is_hip()
                    and self.quantization in rocm_not_supported_quantization):
                raise ValueError(
                    f"{self.quantization} quantization is currently not supported "
                    f"in ROCm.")
            logger.warning(f"{self.quantization} quantization is not fully "
                           "optimized yet. The speed can be slower than "
                           "non-quantized models.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Raise ValueError unless the attention heads divide evenly across
        tensor-parallel ranks and the layers across pipeline stages."""
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_sliding_window(self) -> Optional[int]:
        """Return ``hf_config.sliding_window``, or None if it is absent."""
        return getattr(self.hf_config, "sliding_window", None)

    def get_vocab_size(self) -> int:
        return self.hf_config.vocab_size

    def get_hidden_size(self) -> int:
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        # FIXME(woosuk): This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of hidden layers per pipeline stage."""
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution. Must not exceed 1.0.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        sliding_window: Optional sliding window size; stored as-is on the
            config. None when not provided.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self.sliding_window = sliding_window
        self._verify_args()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def _verify_args(self) -> None:
        """Raise ValueError if ``gpu_memory_utilization`` exceeds 1.0."""
        if self.gpu_memory_utilization > 1.0:
            # The check accepts exactly 1.0, so the message must say so.
            raise ValueError(
                "GPU memory utilization must be less than or equal to 1.0. "
                f"Got {self.gpu_memory_utilization}.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the total swap space fits in this node's CPU memory.

        Raises ValueError when swap would use more than 70% of CPU memory and
        logs a warning above 40%.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)
307

308
309

class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Forced to True
            whenever the world size (pipeline_parallel_size *
            tensor_parallel_size) is greater than 1.
        max_parallel_loading_workers: Optional value stored as-is for callers
            (presumably a cap on workers loading the model concurrently —
            interpreted elsewhere). Defaults to None.

    Raises:
        NotImplementedError: If pipeline_parallel_size is greater than 1.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
        max_parallel_loading_workers: Optional[int] = None,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.world_size = pipeline_parallel_size * tensor_parallel_size
        # A multi-GPU world always runs its workers through Ray, overriding
        # whatever the caller asked for.
        self.worker_use_ray = True if self.world_size > 1 else worker_use_ray
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self._verify_args()

    def _verify_args(self) -> None:
        # Only a single pipeline stage is implemented at the moment.
        if self.pipeline_parallel_size <= 1:
            return
        raise NotImplementedError(
            "Pipeline parallelism is not supported yet.")


class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration. If None, max(max_model_len, 2048) is used.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        max_paddings: Maximum number of paddings to be added to a batch.

    Raises:
        ValueError: If the token budget is smaller than max_model_len or
            smaller than max_num_seqs.
    """

    def __init__(
        self,
        max_num_batched_tokens: Optional[int],
        max_num_seqs: int,
        max_model_len: int,
        max_paddings: int,
    ) -> None:
        # Without an explicit budget, use at least 2048 tokens so that models
        # with a short max_model_len still batch for throughput.
        self.max_num_batched_tokens = (max(max_model_len, 2048)
                                       if max_num_batched_tokens is None else
                                       max_num_batched_tokens)
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.max_paddings = max_paddings
        self._verify_args()

    def _verify_args(self) -> None:
        budget = self.max_num_batched_tokens
        if budget < self.max_model_len:
            raise ValueError(
                f"max_num_batched_tokens ({budget}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
        if budget < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({budget}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")
388
389
390
391
392
393
394
395
396
397


# Case-normalized dtype names accepted from users, mapped to torch dtypes.
# Insertion order matters: _get_and_verify_dtype lists these names in this
# order when reporting the dtypes supported on ROCm.
_STR_DTYPE_TO_TORCH_DTYPE = dict(
    half=torch.float16,
    float16=torch.float16,
    float=torch.float32,
    float32=torch.float32,
    bfloat16=torch.bfloat16,
)

# Names from the mapping above that are rejected when running on ROCm.
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]

400
401
402

def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested ``dtype`` into a torch dtype and validate it.

    Args:
        config: HF model config. Its ``torch_dtype`` (float32 when missing or
            None) is treated as the checkpoint's native dtype.
        dtype: "auto", a case-insensitive key of ``_STR_DTYPE_TO_TORCH_DTYPE``,
            or a ``torch.dtype``.

    Returns:
        The torch dtype to use. "auto" maps a float32 checkpoint to float16
        and otherwise keeps the config's dtype.

    Raises:
        ValueError: If ``dtype`` is not recognized, or resolves to float32
            while running on ROCm.
    """
    # config.torch_dtype can exist but be None, so a getattr default alone
    # would not cover that case.
    native_dtype = getattr(config, "torch_dtype", None)
    if native_dtype is None:
        native_dtype = torch.float32

    if isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    elif isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            # Following common practice, float32 checkpoints run in float16.
            torch_dtype = (torch.float16
                           if native_dtype == torch.float32 else native_dtype)
        elif dtype in _STR_DTYPE_TO_TORCH_DTYPE:
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        else:
            raise ValueError(f"Unknown dtype: {dtype}")
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    if is_hip() and torch_dtype == torch.float32:
        rocm_supported_dtypes = [
            name for name in _STR_DTYPE_TO_TORCH_DTYPE
            if name not in _ROCM_NOT_SUPPORTED_DTYPE
        ]
        raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
                         f"Supported dtypes are {rocm_supported_dtypes}")

    # Upcasting to float32 and downcasting from a float32 checkpoint are both
    # accepted silently; any other mismatch (e.g. float16 <-> bfloat16) is
    # allowed but logged.
    if torch_dtype != native_dtype:
        is_upcast_to_fp32 = torch_dtype == torch.float32
        is_downcast_from_fp32 = native_dtype == torch.float32
        if not (is_upcast_to_fp32 or is_downcast_from_fp32):
            logger.warning(f"Casting {native_dtype} to {torch_dtype}.")

    return torch_dtype
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length.

    The derived maximum is the smallest value among the known length keys
    present on ``hf_config``. If none are present, a user-supplied
    ``max_model_len`` is returned unchanged; otherwise 2048 is assumed with a
    warning. The derived value is then multiplied by the RoPE scaling factor,
    if any. For "yarn" scaling, ``original_max_position_embeddings`` is the
    base value instead.

    Args:
        hf_config: HF model config to inspect.
        max_model_len: User-requested maximum length, or None to use the
            derived value.

    Returns:
        The maximum model length.

    Raises:
        ValueError: If ``max_model_len`` exceeds the derived maximum length.
    """
    derived_max_model_len = float("inf")
    # Name of the config key that determined derived_max_model_len, reported
    # in the error message below. (Previously the loop variable held the
    # *value* of the last key probed, so the message printed e.g. "None=...".)
    max_len_key = None
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        # Same result as min(derived, max_len): the first key wins on ties.
        if max_len is not None and max_len < derived_max_model_len:
            derived_max_model_len = max_len
            max_len_key = key
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            f"{possible_keys}. Assuming the model's maximum length is "
            f"{default_max_len}.")
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        if rope_scaling["type"] == "yarn":
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
            max_len_key = "original_max_position_embeddings"
        derived_max_model_len *= scaling_factor

    if max_model_len is None:
        max_model_len = derived_max_model_len
    elif max_model_len > derived_max_model_len:
        raise ValueError(
            f"User-specified max_model_len ({max_model_len}) is greater than "
            f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
            " in model's config.json). This may lead to incorrect model "
            "outputs or CUDA errors. Make sure the value is correct and "
            "within the model context size.")
    return int(max_model_len)