config.py 28.7 KB
Newer Older
1
2
from typing import Optional, Union, ClassVar
from dataclasses import dataclass
3
import os
4
from packaging.version import Version
5
6

import torch
7
from transformers import PretrainedConfig
8

Woosuk Kwon's avatar
Woosuk Kwon committed
9
from vllm.logger import init_logger
10
from vllm.transformers_utils.config import get_config
11
from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version
12
13
14

# Module-level logger for this file, created through vLLM's init_logger
# helper and keyed by this module's name.
logger = init_logger(__name__)

15
# Number of bytes in one GiB (2**30). Used below to convert the swap space
# from GiB to bytes and to render byte counts as GiB in messages.
_GB = 1 << 30
16

17
18

class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights, default to the
            default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if safetensors format is
                not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16 precision
            for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id. If unspecified, will use
            the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        max_logprobs: Stored unchanged as ``self.max_logprobs`` and not
            validated here. Presumably the maximum number of log
            probabilities a request may ask for -- TODO confirm against the
            code that consumes this attribute.

    Raises:
        ValueError: If the load format, tokenizer mode, quantization method,
            dtype or maximum length fails validation.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
        max_logprobs: int = 5,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        self.enforce_eager = enforce_eager
        self.max_context_len_to_capture = max_context_len_to_capture
        self.max_logprobs = max_logprobs

        if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            from modelscope.hub.snapshot_download import snapshot_download  # pylint: disable=C

            if not os.path.exists(model):
                model_path = snapshot_download(model_id=model,
                                               cache_dir=download_dir,
                                               revision=revision)
            else:
                model_path = model
            # The resolved local path replaces the model, download dir and
            # tokenizer locations.
            self.model = model_path
            self.download_dir = model_path
            self.tokenizer = model_path

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                     max_model_len)
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()
        self._verify_cuda_graph()

    def _verify_load_format(self) -> None:
        """Normalize ``self.load_format`` to lower case and validate it.

        Raises:
            ValueError: If the format is unknown, is unsupported on ROCm, or
                is "pt" for a Mixtral model.
        """
        load_format = self.load_format.lower()
        supported_load_format = [
            "auto", "pt", "safetensors", "npcache", "dummy"
        ]
        rocm_not_supported_load_format = []
        if load_format not in supported_load_format:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        # Cheap membership test first; the platform is only queried when the
        # format is actually on the ROCm deny list.
        if load_format in rocm_not_supported_load_format and is_hip():
            rocm_supported_load_format = [
                f for f in supported_load_format
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format \'{load_format}\' is not supported in ROCm. "
                f"Supported load format are "
                f"{rocm_supported_load_format}")

        # TODO: Remove this check once HF updates the pt weights of Mixtral.
        # NOTE: the getattr default alone is not enough because the attribute
        # may exist with the value None; fall back to an empty list so the
        # membership test below cannot raise TypeError.
        architectures = getattr(self.hf_config, "architectures", None) or []
        if "MixtralForCausalLM" in architectures and load_format == "pt":
            raise ValueError(
                "Currently, the 'pt' format is not supported for Mixtral. "
                "Please use the 'safetensors' format instead. ")
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:
        """Normalize ``self.tokenizer_mode`` to lower case and validate it.

        Raises:
            ValueError: If the mode is neither "auto" nor "slow".
        """
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        """Reconcile the requested quantization with the HF model config.

        The user-supplied method is lower-cased. If the HF config carries a
        ``quantization_config``, its method is adopted when none was
        requested, and must match otherwise.

        Raises:
            ValueError: If the two methods disagree, the method is unknown,
                or the method is not supported on ROCm.
        """
        supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
        rocm_not_supported_quantization = ["awq", "marlin"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
        if hf_quant_config is not None:
            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
            # If the GPTQ model is serialized in marlin format, use marlin.
            if (hf_quant_method == "gptq"
                    and "is_marlin_format" in hf_quant_config
                    and hf_quant_config["is_marlin_format"]):
                hf_quant_method = "marlin"
            if self.quantization is None:
                self.quantization = hf_quant_method
            elif self.quantization != hf_quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({hf_quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if is_hip(
            ) and self.quantization in rocm_not_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm.")
            if self.quantization != "marlin":
                logger.warning(
                    f"{self.quantization} quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.")

    def _verify_cuda_graph(self) -> None:
        """Default and clamp ``max_context_len_to_capture``.

        The value defaults to ``max_model_len`` when unset and is never
        allowed to exceed it.
        """
        if self.max_context_len_to_capture is None:
            self.max_context_len_to_capture = self.max_model_len
        self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                              self.max_model_len)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check that the model can be split as the parallel config requests.

        Raises:
            ValueError: If the number of attention heads is not divisible by
                the tensor parallel size, or the number of hidden layers is
                not divisible by the pipeline parallel size.
        """
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_sliding_window(self) -> Optional[int]:
        """Return the HF config's ``sliding_window``, or None if absent."""
        return getattr(self.hf_config, "sliding_window", None)

    def get_vocab_size(self) -> int:
        """Return the HF config's ``vocab_size``."""
        return self.hf_config.vocab_size

    def get_hidden_size(self) -> int:
        """Return the HF config's ``hidden_size``."""
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        """Return the per-head size.

        Uses the HF config's ``head_dim`` when present, otherwise
        ``hidden_size // num_attention_heads``.
        """
        if hasattr(self.hf_config, "head_dim"):
            return self.hf_config.head_dim
        # FIXME(woosuk): This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        # Probed in order; the first attribute that is present and not None
        # wins.
        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        """Return the hidden-layer count divided by the pipeline parallel size."""
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution. Values greater than 1.0 are rejected.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for kv cache storage. Either "auto" or
            "fp8_e5m2".
        sliding_window: Stored unchanged on the config; presumably the
            model's sliding-window size, if any -- TODO confirm against the
            caller.
        enable_prefix_caching: Stored unchanged on the config; not
            interpreted by this class.

    Raises:
        ValueError: If ``gpu_memory_utilization`` exceeds 1.0, the cache
            dtype is unknown, or FP8 is requested with a CUDA version lower
            than 11.8.
        NotImplementedError: If the FP8 KV cache is requested on an AMD GPU.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
        cache_dtype: str,
        sliding_window: Optional[int] = None,
        enable_prefix_caching: bool = False,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        # The swap space is given in GiB but stored in bytes.
        self.swap_space_bytes = swap_space * _GB
        self.cache_dtype = cache_dtype
        self.sliding_window = sliding_window
        self.enable_prefix_caching = enable_prefix_caching
        self._verify_args()
        self._verify_cache_dtype()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def metrics_info(self):
        """Return every instance attribute as a ``{name: str(value)}`` dict."""
        # convert cache_config to dict(key: str, value: str) for prometheus
        # metrics info
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self) -> None:
        """Reject a GPU memory utilization greater than 1.0.

        Raises:
            ValueError: If ``gpu_memory_utilization`` is greater than 1.0.
        """
        if self.gpu_memory_utilization > 1.0:
            # Exactly 1.0 is accepted, so the message says "or equal to".
            raise ValueError(
                "GPU memory utilization must be less than or equal to 1.0. "
                f"Got {self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        """Validate ``cache_dtype`` and the platform requirements of FP8.

        Raises:
            NotImplementedError: If "fp8_e5m2" is requested on an AMD GPU.
            ValueError: If the dtype is unknown, or "fp8_e5m2" is requested
                with a CUDA version lower than 11.8.
        """
        if self.cache_dtype == "auto":
            pass
        elif self.cache_dtype == "fp8_e5m2":
            if is_hip():
                raise NotImplementedError(
                    "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
            nvcc_cuda_version = get_nvcc_cuda_version()
            # A falsy version skips the check instead of failing.
            if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
                raise ValueError(
                    "FP8 is not supported when cuda version is lower than 11.8."
                )
            logger.info(
                "Using fp8_e5m2 data type to store kv cache. It reduces "
                "the GPU memory footprint and boosts the performance. "
                "But it may cause slight accuracy drop. "
                "Currently we only support fp8 without scaling factors and "
                "make e5m2 as a default format.")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Check the total swap space against the available CPU memory.

        Warns when the swap space of all GPUs on the node exceeds 40% of the
        total CPU memory.

        Raises:
            ValueError: If it exceeds 70% of the total CPU memory.
        """
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)
372

373
374

class ParallelConfig:
    """Settings that control distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether model workers run on Ray. Forced to True
            when the world size is greater than 1 and the backend is not
            Neuron.
        max_parallel_loading_workers: Maximum number of multiple batches
            when load model sequentially. To avoid RAM OOM when using tensor
            parallel and large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL.
        ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
            https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
        ray_workers_use_nsight: bool = False,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        if not is_neuron():
            self.tensor_parallel_size = tensor_parallel_size
        else:
            # On Neuron, vLLM itself does no sharding (TP is pinned to 1).
            # The requested degree is kept in neuron_tp_degree so that
            # transformer-neuronx can distribute the workload to multiple
            # NeuronCores.
            self.tensor_parallel_size = 1
            self.neuron_tp_degree = tensor_parallel_size
        self.worker_use_ray = worker_use_ray
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce
        self.ray_workers_use_nsight = ray_workers_use_nsight

        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
        # Multi-worker runs always go through Ray, except on Neuron where
        # Ray workers are not supported.
        if self.world_size > 1 and not is_neuron():
            self.worker_use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        """Validate the settings and switch off unsupported features.

        Raises:
            NotImplementedError: If pipeline parallelism is requested.
            ValueError: If nsight profiling is requested without Ray workers.
        """
        uses_pipeline_parallelism = self.pipeline_parallel_size > 1
        if uses_pipeline_parallelism:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")

        multi_worker = self.world_size > 1

        # Platform- or mode-specific reasons to drop the custom kernel.
        if multi_worker and not self.disable_custom_all_reduce:
            if is_hip():
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported on AMD GPUs.")
            elif uses_pipeline_parallelism:
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported with pipeline parallelism.")

        if self.ray_workers_use_nsight and not self.worker_use_ray:
            raise ValueError("Unable to use nsight profiling unless workers "
                             "run with Ray.")

        # FIXME(woosuk): Fix the stability issues and re-enable the custom
        # all-reduce kernel.
        if multi_worker and not self.disable_custom_all_reduce:
            self.disable_custom_all_reduce = True
            logger.info(
                "Custom all-reduce kernels are temporarily disabled due to "
                "stability issues. We will re-enable them once the issues are "
                "resolved.")
449
450
451


class SchedulerConfig:
    """Limits applied by the scheduler.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration. If None, ``max(max_model_len, 2048)`` is
            used.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        max_paddings: Maximum number of paddings to be added to a batch.

    Raises:
        ValueError: If ``max_num_batched_tokens`` is smaller than
            ``max_model_len`` or ``max_num_seqs``.
    """

    def __init__(
        self,
        max_num_batched_tokens: Optional[int],
        max_num_seqs: int,
        max_model_len: int,
        max_paddings: int,
    ) -> None:
        if max_num_batched_tokens is None:
            # If max_model_len is too short, use 2048 as the default value
            # for higher throughput.
            max_num_batched_tokens = max(max_model_len, 2048)
        self.max_num_batched_tokens = max_num_batched_tokens
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.max_paddings = max_paddings
        self._verify_args()

    def _verify_args(self) -> None:
        """Check the token budget against the sequence limits.

        Raises:
            ValueError: If the token budget is below ``max_model_len`` or
                below ``max_num_seqs``.
        """
        token_budget = self.max_num_batched_tokens
        if token_budget < self.max_model_len:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
        if token_budget < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")
496
497


498
499
class DeviceConfig:
    """Selects the device type and the torch device derived from it.

    Args:
        device: Device type to use. "auto" picks "cuda" when CUDA is
            available, otherwise "neuron" when running on Neuron. Any other
            value is taken as the device type verbatim.

    Raises:
        RuntimeError: If ``device`` is "auto" and neither CUDA nor Neuron is
            detected.
    """

    def __init__(self, device: str = "auto") -> None:
        if device != "auto":
            # Device type is assigned explicitly
            self.device_type = device
        elif torch.cuda.is_available():
            self.device_type = "cuda"
        elif is_neuron():
            self.device_type = "neuron"
        else:
            raise RuntimeError("No supported device detected.")

        # Some device types require processing inputs on CPU; all others map
        # directly onto a torch device of the same name.
        cpu_input_device_types = ["neuron"]
        if self.device_type in cpu_input_device_types:
            torch_device_name = "cpu"
        else:
            torch_device_name = self.device_type
        self.device = torch.device(torch_device_name)

    @property
    def is_neuron(self):
        """Whether the selected device type is "neuron"."""
        return self.device_type == "neuron"
523
524


525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
@dataclass
class LoRAConfig:
    """Configuration for LoRA support.

    Validation of the fields happens in ``__post_init__``; cross-checks
    against the model and scheduler configs are separate methods.
    """
    max_lora_rank: int
    max_loras: int
    # Defaults to max_loras when left as None; must be >= max_loras.
    max_cpu_loras: Optional[int] = None
    # None or "auto" means "use the model's dtype". A string such as
    # "float16" is resolved to the torch dtype of that name by
    # verify_with_model_config.
    lora_dtype: Optional[Union[str, torch.dtype]] = None
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256

    def __post_init__(self):
        """Validate the fields and default ``max_cpu_loras``.

        Raises:
            ValueError: If the rank or extra vocab size is not one of the
                allowed values, ``max_loras`` is below 1, or
                ``max_cpu_loras`` is below ``max_loras``.
        """
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        possible_max_ranks = (8, 16, 32, 64)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: "ModelConfig"):
        """Resolve ``lora_dtype`` and check compatibility with the model.

        Raises:
            ValueError: If ``lora_dtype`` is a string that does not name a
                torch dtype, or the model is quantized.
        """
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            # A plain getattr would raise AttributeError for unknown names
            # and silently accept non-dtype attributes (e.g. "nn"), so the
            # result is checked explicitly.
            resolved_dtype = getattr(torch, self.lora_dtype, None)
            if not isinstance(resolved_dtype, torch.dtype):
                raise ValueError(f"Unknown lora_dtype: {self.lora_dtype}")
            self.lora_dtype = resolved_dtype
        if model_config.quantization is not None:
            raise ValueError(
                "LoRA is not supported with quantized models yet.")

    def verify_with_scheduler_config(self,
                                     scheduler_config: "SchedulerConfig"):
        """Check the scheduler's token budget against the LoRA kernel limit.

        Raises:
            ValueError: If ``max_num_batched_tokens`` is greater than 65528.
        """
        if scheduler_config.max_num_batched_tokens > 65528:
            raise ValueError(
                "Due to limitations of the custom LoRA CUDA kernel, "
                "max_num_batched_tokens must be <= 65528 when "
                "LoRA is enabled.")


573
574
575
576
577
578
579
580
# Maps the dtype names accepted as strings to their torch dtypes.
# "half"/"float16" and "float"/"float32" are aliases for the same dtype.
_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

581
582
# dtype names left out of the "supported dtypes" list that
# _get_and_verify_dtype reports when it rejects float32 on ROCm.
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]

583
584
585

def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    """Resolve the requested dtype against the model config's dtype.

    Args:
        config: HF model config; its ``torch_dtype`` (float32 when missing
            or None) is the model's native dtype.
        dtype: Either a torch dtype, a dtype name from
            ``_STR_DTYPE_TO_TORCH_DTYPE`` (case-insensitive), or "auto".
            "auto" keeps the model's dtype, except that float32 models are
            run in float16.

    Returns:
        The torch dtype to use.

    Raises:
        ValueError: If ``dtype`` is unknown, or float32 is selected on ROCm.
    """
    # NOTE: a getattr default of torch.float32 would not be correct here
    # because config.torch_dtype can be present with the value None.
    model_dtype = getattr(config, "torch_dtype", None)
    if model_dtype is None:
        model_dtype = torch.float32

    if isinstance(dtype, torch.dtype):
        resolved_dtype = dtype
    elif not isinstance(dtype, str):
        raise ValueError(f"Unknown dtype: {dtype}")
    else:
        dtype = dtype.lower()
        if dtype == "auto":
            # Following the common practice, we use float16 for float32
            # models.
            resolved_dtype = (torch.float16
                              if model_dtype == torch.float32 else model_dtype)
        elif dtype in _STR_DTYPE_TO_TORCH_DTYPE:
            resolved_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
        else:
            raise ValueError(f"Unknown dtype: {dtype}")

    if is_hip() and resolved_dtype == torch.float32:
        rocm_supported_dtypes = [
            name for name in _STR_DTYPE_TO_TORCH_DTYPE
            if name not in _ROCM_NOT_SUPPORTED_DTYPE
        ]
        raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
                         f"Supported dtypes are {rocm_supported_dtypes}")

    # Verify the dtype. Upcasting to float32 and downcasting from float32
    # are both allowed silently; only a cast between float16 and bfloat16
    # is allowed with a warning.
    involves_float32 = (resolved_dtype == torch.float32
                        or model_dtype == torch.float32)
    if resolved_dtype != model_dtype and not involves_float32:
        logger.warning(f"Casting {model_dtype} to {resolved_dtype}.")

    return resolved_dtype
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length."""
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
648
649
        # ChatGLM2
        "seq_length",
650
651
652
653
654
655
656
657
658
        # Others
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    for key in possible_keys:
        max_len_key = getattr(hf_config, key, None)
        if max_len_key is not None:
            derived_max_model_len = min(derived_max_model_len, max_len_key)
659
    if derived_max_model_len == float("inf"):
660
661
662
663
664
665
666
667
668
669
670
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            f"{possible_keys}. Assuming the model's maximum length is "
            f"{default_max_len}.")
        derived_max_model_len = default_max_len
671

672
673
674
675
    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
Antoni Baum's avatar
Antoni Baum committed
676
677
678
        if rope_scaling["type"] == "yarn":
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
679
680
        derived_max_model_len *= scaling_factor

681
682
683
684
685
686
687
688
689
    if max_model_len is None:
        max_model_len = derived_max_model_len
    elif max_model_len > derived_max_model_len:
        raise ValueError(
            f"User-specified max_model_len ({max_model_len}) is greater than "
            f"the derived max_model_len ({max_len_key}={derived_max_model_len}"
            " in model's config.json). This may lead to incorrect model "
            "outputs or CUDA errors. Make sure the value is correct and "
            "within the model context size.")
690
    return int(max_model_len)