"vscode:/vscode.git/clone" did not exist on "16851efa0f5248b27f86eb69dbc56900d1cc1c94"
auto.py 5.65 KB
Newer Older
1
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Adapted from
https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py
"""

import warnings
from typing import Dict, Optional, Union

from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
from .gguf import GGUFQuantizer
from .quantization_config import (
    BitsAndBytesConfig,
    GGUFQuantizationConfig,
    QuantizationConfigMixin,
    QuantizationMethod,
    QuantoConfig,
    TorchAoConfig,
)
from .quanto import QuantoQuantizer
from .torchao import TorchAoHfQuantizer


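# Maps a quantization config's `quant_method` identifier to the quantizer class that implements it.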
AUTO_QUANTIZER_MAPPING = {
    "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
    "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
    "gguf": GGUFQuantizer,
    "quanto": QuantoQuantizer,
    "torchao": TorchAoHfQuantizer,
}

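# Maps the same identifiers to their config classes; note that both bnb variants share `BitsAndBytesConfig`.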
AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "bitsandbytes_4bit": BitsAndBytesConfig,
    "bitsandbytes_8bit": BitsAndBytesConfig,
    "gguf": GGUFQuantizationConfig,
    "quanto": QuantoConfig,
    "torchao": TorchAoConfig,
}


class DiffusersAutoQuantizer:
    """
55
56
     The auto diffusers quantizer class that takes care of automatically instantiating to the correct
    `DiffusersQuantizer` given the `QuantizationConfig`.
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
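
    Example (a minimal usage sketch; assumes `bitsandbytes` is installed):

    ```py
    from diffusers import BitsAndBytesConfig
    from diffusers.quantizers import DiffusersAutoQuantizer

    config = BitsAndBytesConfig(load_in_4bit=True)
    quantizer = DiffusersAutoQuantizer.from_config(config)  # returns a BnB4BitDiffusersQuantizer
    ```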
    """

    @classmethod
    def from_dict(cls, quantization_config_dict: Dict):
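        """
        Instantiate a quantization config class from a serialized config dict, dispatching on its `quant_method`
        key (with backward-compatible handling of the legacy bnb `load_in_4bit`/`load_in_8bit` flags).
        """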
        quant_method = quantization_config_dict.get("quant_method", None)
        # We need special care for bnb models to make sure everything is backward-compatible.
        if quantization_config_dict.get("load_in_8bit", False) or quantization_config_dict.get("load_in_4bit", False):
            suffix = "_4bit" if quantization_config_dict.get("load_in_4bit", False) else "_8bit"
            quant_method = QuantizationMethod.BITS_AND_BYTES + suffix
        elif quant_method is None:
            raise ValueError(
                "The model's quantization config from the arguments has no `quant_method` attribute. Make sure that the model has been correctly quantized"
            )

        if quant_method not in AUTO_QUANTIZATION_CONFIG_MAPPING.keys():
            raise ValueError(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
            )

        target_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]
        return target_cls.from_dict(quantization_config_dict)

    @classmethod
    def from_config(cls, quantization_config: Union[QuantizationConfigMixin, Dict], **kwargs):
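        """
        Instantiate the matching `DiffusersQuantizer` from a `QuantizationConfigMixin` (or its dict form),
        forwarding any extra `kwargs` to the quantizer's constructor.
        """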
        # Convert it to a QuantizationConfig if quantization_config is a dict
        if isinstance(quantization_config, dict):
            quantization_config = cls.from_dict(quantization_config)

        quant_method = quantization_config.quant_method

        # Again, we need special care for bnb, as we have a single quantization config
        # class for both 4-bit and 8-bit quantization.
        if quant_method == QuantizationMethod.BITS_AND_BYTES:
            if quantization_config.load_in_8bit:
                quant_method += "_8bit"
            else:
                quant_method += "_4bit"

        if quant_method not in AUTO_QUANTIZER_MAPPING.keys():
            raise ValueError(
                f"Unknown quantization type, got {quant_method} - supported types are:"
                f" {list(AUTO_QUANTIZER_MAPPING.keys())}"
            )

        target_cls = AUTO_QUANTIZER_MAPPING[quant_method]
        return target_cls(quantization_config, **kwargs)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
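        """
        Build a quantizer from the `quantization_config` stored in the config of the model at
        `pretrained_model_name_or_path`, updated with any matching `kwargs`.
        """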
        model_config = cls.load_config(pretrained_model_name_or_path, **kwargs)
        if getattr(model_config, "quantization_config", None) is None:
            raise ValueError(
                f"Did not found a `quantization_config` in {pretrained_model_name_or_path}. Make sure that the model is correctly quantized."
            )
        quantization_config_dict = model_config.quantization_config
        quantization_config = cls.from_dict(quantization_config_dict)
        # Update with potential kwargs that are passed through from_pretrained.
        quantization_config.update(kwargs)

        return cls.from_config(quantization_config)

    @classmethod
    def merge_quantization_configs(
        cls,
        quantization_config: Union[dict, QuantizationConfigMixin],
        quantization_config_from_args: Optional[QuantizationConfigMixin],
    ):
        """
        Handles situations where both a `quantization_config` from the args and a `quantization_config` from the
        model config are present.
        """
        if quantization_config_from_args is not None:
            warning_msg = (
                "You passed `quantization_config` or equivalent parameters to `from_pretrained` but the model you're loading"
                " already has a `quantization_config` attribute. The `quantization_config` from the model will be used."
            )
        else:
            warning_msg = ""

        if isinstance(quantization_config, dict):
            quantization_config = cls.from_dict(quantization_config)

        if warning_msg != "":
            warnings.warn(warning_msg)

        return quantization_config