"src/targets/vscode:/vscode.git/clone" did not exist on "8beb6680a9dab66d3a437c75cbd7fd7649743c1f"
config.py 23.1 KB
Newer Older
1
2
from typing import Optional, Union, ClassVar
from dataclasses import dataclass
3
import os
4
5

import torch
6
from transformers import PretrainedConfig
7

Woosuk Kwon's avatar
Woosuk Kwon committed
8
from vllm.logger import init_logger
9
from vllm.transformers_utils.config import get_config
10
from vllm.utils import get_cpu_memory, is_hip
11
12
13

logger = init_logger(__name__)

# Number of bytes in one GiB; CacheConfig uses it to convert `swap_space`
# (given in GiB) to bytes.
_GB = 1 << 30


class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.

    If the environment variable ``VLLM_USE_MODELSCOPE`` is "true"
    (case-insensitive), the model is downloaded from the ModelScope hub and
    ``model``, ``tokenizer`` and ``download_dir`` are all replaced by the
    local snapshot path.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture

        if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            from modelscope.hub.snapshot_download import snapshot_download  # pylint: disable=C
            model_path = snapshot_download(model_id=model,
                                           cache_dir=download_dir,
                                           revision=revision)
            self.model = model_path
            self.download_dir = model_path
            self.tokenizer = model_path

        # The HF config must be loaded first: the dtype/max-len helpers and
        # the _verify_* methods below all read it.
        self.hf_config = get_config(self.model, trust_remote_code, revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                     max_model_len)
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()
        # Must run after max_model_len is set, since it clamps against it.
        self._verify_cuda_graph()

    def _verify_load_format(self) -> None:
        """Lower-case and validate ``self.load_format``.

        Raises:
            ValueError: If the format is unknown, unsupported on ROCm, or
                'pt' for a Mixtral model.
        """
        load_format = self.load_format.lower()
        supported_load_format = [
            "auto", "pt", "safetensors", "npcache", "dummy"
        ]
        rocm_not_supported_load_format = []
        if load_format not in supported_load_format:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f for f in supported_load_format
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format \'{load_format}\' is not supported in ROCm. "
                f"Supported load format are "
                f"{rocm_supported_load_format}")

        # TODO: Remove this check once HF updates the pt weights of Mixtral.
        architectures = getattr(self.hf_config, "architectures", [])
        if "MixtralForCausalLM" in architectures and load_format == "pt":
            raise ValueError(
                "Currently, the 'pt' format is not supported for Mixtral. "
                "Please use the 'safetensors' format instead. ")
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:
        """Lower-case ``self.tokenizer_mode``; raise ValueError unless it is
        'auto' or 'slow'."""
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        """Resolve and validate ``self.quantization``.

        The user-specified method is lower-cased and reconciled with
        ``quantization_config["quant_method"]`` from the HF config (if
        present). If the user gave none, the HF value is adopted.

        Raises:
            ValueError: On a mismatch with the HF config, an unknown method,
                or a method unsupported on ROCm.
        """
        supported_quantization = ["awq", "gptq", "squeezellm"]
        rocm_not_supported_quantization = ["awq"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
        if hf_quant_config is not None:
            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
            if self.quantization is None:
                self.quantization = hf_quant_method
            elif self.quantization != hf_quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({hf_quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if is_hip(
            ) and self.quantization in rocm_not_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not supported "
                    f"in ROCm.")
            logger.warning(f"{self.quantization} quantization is not fully "
                           "optimized yet. The speed can be slower than "
                           "non-quantized models.")

    def _verify_cuda_graph(self) -> None:
        """Default ``max_context_len_to_capture`` to ``max_model_len`` and
        clamp it so it never exceeds ``max_model_len``."""
        if self.max_context_len_to_capture is None:
            self.max_context_len_to_capture = self.max_model_len
        self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                              self.max_model_len)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the model can be split by ``parallel_config``.

        Raises:
            ValueError: If the attention-head count is not divisible by the
                tensor parallel size, or the hidden-layer count is not
                divisible by the pipeline parallel size.
        """
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_sliding_window(self) -> Optional[int]:
        """Return ``hf_config.sliding_window``, or None if it is absent."""
        return getattr(self.hf_config, "sliding_window", None)

    def get_vocab_size(self) -> int:
        """Return ``hf_config.vocab_size``."""
        return self.hf_config.vocab_size

    def get_hidden_size(self) -> int:
        """Return ``hf_config.hidden_size``."""
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        """Return ``hidden_size // num_attention_heads``."""
        # FIXME(woosuk): This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # The first attribute present (not None) wins, so the order matters.
        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the number of hidden layers per pipeline stage."""
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution. Must be at most 1.0.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        sliding_window: Optional sliding-window length; stored as-is.
            Presumably the value of ModelConfig.get_sliding_window() — confirm
            against callers.

    Raises:
        ValueError: If ``gpu_memory_utilization`` is greater than 1.0.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        # Stored in bytes; the argument is given in GiB.
        self.swap_space_bytes = swap_space * _GB
        self.sliding_window = sliding_window
        self._verify_args()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def _verify_args(self) -> None:
        """Raise ValueError if ``gpu_memory_utilization`` exceeds 1.0."""
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than or equal to 1.0. "
                f"Got {self.gpu_memory_utilization}.")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check total swap space against the host's CPU memory.

        The swap space is multiplied by the tensor parallel size. Usage above
        70% of total CPU memory raises ValueError. Usage above 40% logs a
        warning.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)


class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Will be set to
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.
        max_parallel_loading_workers: Optional limit on parallel model-loading
            workers; stored as-is here (its consumer is outside this class —
            confirm semantics there).

    Raises:
        NotImplementedError: If ``pipeline_parallel_size`` is greater than 1.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
        max_parallel_loading_workers: Optional[int] = None,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray
        self.max_parallel_loading_workers = max_parallel_loading_workers

        self.world_size = pipeline_parallel_size * tensor_parallel_size
        # Any multi-GPU setup forces Ray, regardless of the argument.
        if self.world_size > 1:
            self.worker_use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        """Reject pipeline parallelism, which is not implemented."""
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")


class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration. If None, defaults to
            ``max(max_model_len, 2048)``.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        max_paddings: Maximum number of paddings to be added to a batch.

    Raises:
        ValueError: If ``max_num_batched_tokens`` is smaller than
            ``max_model_len`` or smaller than ``max_num_seqs``.
    """

    def __init__(
        self,
        max_num_batched_tokens: Optional[int],
        max_num_seqs: int,
        max_model_len: int,
        max_paddings: int,
    ) -> None:
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
        else:
            # If max_model_len is too short, use 2048 as the default value for
            # higher throughput.
            self.max_num_batched_tokens = max(max_model_len, 2048)
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.max_paddings = max_paddings
        self._verify_args()

    def _verify_args(self) -> None:
        """Check that one iteration can fit a full sequence and one token per
        scheduled sequence."""
        if self.max_num_batched_tokens < self.max_model_len:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")


@dataclass
class LoRAConfig:
    """Configuration for LoRA adapters.

    Attributes:
        max_lora_rank: Maximum LoRA rank; must be one of (8, 16, 32, 64).
        max_loras: Maximum number of LoRAs; must be >= 1.
        max_cpu_loras: Maximum number of LoRAs kept on CPU. Defaults to
            ``max_loras`` and must not be smaller than it.
        lora_dtype: None/"auto" (use the model dtype), a torch dtype name
            such as "float16", or a ``torch.dtype``. Resolved to a
            ``torch.dtype`` by ``verify_with_model_config``.
        lora_extra_vocab_size: Must be one of (0, 256, 512).
    """
    max_lora_rank: int
    max_loras: int
    max_cpu_loras: Optional[int] = None
    lora_dtype: Optional[Union[str, torch.dtype]] = None
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256

    def __post_init__(self):
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        possible_max_ranks = (8, 16, 32, 64)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            # The bound is max_loras, so report it under its own name.
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: "ModelConfig"):
        """Resolve ``lora_dtype`` against the model and reject quantized
        models.

        Raises:
            ValueError: If ``model_config.quantization`` is set.
        """
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)
        if model_config.quantization is not None:
            raise ValueError(
                "LoRA is not supported with quantized models yet.")

    def verify_with_scheduler_config(self,
                                     scheduler_config: "SchedulerConfig"):
        """Raise ValueError if ``max_num_batched_tokens`` exceeds 65528."""
        if scheduler_config.max_num_batched_tokens > 65528:
            raise ValueError(
                "Due to limitations of the custom LoRA CUDA kernel, "
                "max_num_batched_tokens must be <= 65528 when "
                "LoRA is enabled.")


# Mapping from user-facing dtype strings (lower-case) to torch dtypes.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

# Keys of _STR_DTYPE_TO_TORCH_DTYPE that are rejected on ROCm (see
# _get_and_verify_dtype, which rejects float32 there).
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]


def _get_and_verify_dtype(
    config: "PretrainedConfig",
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the dtype used for model weights and activations.

    Args:
        config: HuggingFace model config. Its ``torch_dtype`` attribute is
            taken as the model's native dtype (float32 if absent or None).
        dtype: "auto", a key of ``_STR_DTYPE_TO_TORCH_DTYPE``
            (case-insensitive), or a ``torch.dtype``.

    Returns:
        The resolved ``torch.dtype``. "auto" maps float32 models to float16
        and keeps any other native dtype.

    Raises:
        ValueError: If ``dtype`` is unrecognized, or if it resolves to
            float32 while running on ROCm.
    """
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            # Following the common practice, we use float16 for float32
            # models.
            torch_dtype = (torch.float16
                           if config_dtype == torch.float32 else config_dtype)
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    if is_hip() and torch_dtype == torch.float32:
        rocm_supported_dtypes = [
            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
        ]
        raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
                         f"Supported dtypes are {rocm_supported_dtypes}")

    # Verify the dtype. Upcasting to float32 and downcasting from float32
    # are allowed silently; any other change (e.g. float16 <-> bfloat16) is
    # allowed with a warning.
    if (torch_dtype != config_dtype and torch_dtype != torch.float32
            and config_dtype != torch.float32):
        logger.warning(f"Casting {config_dtype} to {torch_dtype}.")

    return torch_dtype


def _get_and_verify_max_len(
    hf_config: "PretrainedConfig",
    max_model_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length.

    The derived maximum length is the smallest value among the known
    length-related keys of ``hf_config``, then adjusted for RoPE scaling.

    Args:
        hf_config: HuggingFace model config.
        max_model_len: User-specified maximum length, or None to use the
            derived value.

    Returns:
        The maximum model length. If no length key is present and the user
        supplied a value, that value is returned unchanged.

    Raises:
        ValueError: If ``max_model_len`` exceeds the derived maximum length.
    """
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    # Name of the config key that produced derived_max_model_len. It is used
    # only in the error message below.
    max_len_key = None
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None and max_len < derived_max_model_len:
            max_len_key = key
            derived_max_model_len = max_len

    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            f"{possible_keys}. Assuming the model's maximum length is "
            f"{default_max_len}.")
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        # .get: some configs omit "type"; those are treated as non-YaRN.
        if rope_scaling.get("type") == "yarn":
            # YaRN scales the original (pre-extension) context length.
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
        derived_max_model_len *= scaling_factor

    if max_model_len is None:
        max_model_len = derived_max_model_len
    elif max_model_len > derived_max_model_len:
        raise ValueError(
            f"User-specified max_model_len ({max_model_len}) is greater than "
            f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
            " in model's config.json). This may lead to incorrect model "
            "outputs or CUDA errors. Make sure the value is correct and "
            "within the model context size.")
    return int(max_model_len)