# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PEFT utilities: Utilities related to peft library
"""
17

18
import collections
19
import importlib
20
from typing import Optional
21

22
23
24
from packaging import version

from .import_utils import is_peft_available, is_torch_available
25
26


27
28
29
if is_torch_available():
    import torch

30

31
def recurse_remove_peft_layers(model):
    r"""
    Recursively replace every PEFT tuner layer in `model` with the plain layer it wraps.

    Two code paths are supported:

    - If the first `BaseTunerLayer` found has a `base_layer` attribute, every module whose qualified name does not
      contain "lora" and that has a `base_layer` is swapped, on its parent, for `module.get_base_layer()`.
    - Otherwise (backwards compatibility with PEFT <= 0.6.2), modules that are both a `LoraLayer` and a
      `torch.nn.Linear` / `torch.nn.Conv2d` are rebuilt as plain `torch.nn.Linear` / `torch.nn.Conv2d` layers that
      reuse the original `weight` and `bias` parameters.

    Args:
        model (`torch.nn.Module`):
            The model to strip; it is modified in place.

    Returns:
        `torch.nn.Module`: The same `model` object, with the PEFT layers removed.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    has_base_layer_pattern = False
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            # Only the first tuner layer is inspected to decide which PEFT layout the model uses.
            has_base_layer_pattern = hasattr(module, "base_layer")
            break

    if has_base_layer_pattern:
        from peft.utils import _get_submodules

        # Snapshot the names up front because the loop below mutates the module tree.
        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
        for key in key_list:
            try:
                parent, target, target_name = _get_submodules(model, key)
            except AttributeError:
                # Presumably the path is no longer reachable because an ancestor was already replaced.
                continue
            if hasattr(target, "base_layer"):
                setattr(parent, target_name, target.get_base_layer())
    else:
        # This is for backwards compatibility with PEFT <= 0.6.2.
        # TODO can be removed once that PEFT version is no longer supported.
        from peft.tuners.lora import LoraLayer

        for name, module in model.named_children():
            if len(list(module.children())) > 0:
                # compound module, go inside it
                recurse_remove_peft_layers(module)

            module_replaced = False

            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
                new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
                    module.weight.device
                )
                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias

                module_replaced = True
            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
                # `bias` must mirror the original layer: nn.Conv2d defaults to bias=True, which would otherwise
                # leave a freshly (randomly) initialised bias on layers that never had one.
                new_module = torch.nn.Conv2d(
                    module.in_channels,
                    module.out_channels,
                    module.kernel_size,
                    module.stride,
                    module.padding,
                    module.dilation,
                    module.groups,
                    bias=module.bias is not None,
                    padding_mode=module.padding_mode,
                ).to(module.weight.device)

                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias

                module_replaced = True

            if module_replaced:
                setattr(model, name, new_module)
                del module

                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

    return model
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117


def scale_lora_layers(model, weight):
    """
    Apply `weight` to every LoRA (PEFT tuner) layer of `model`.

    Each `BaseTunerLayer` found among `model.modules()` receives `weight` through its `scale_layer` method; the
    model is modified in place.

    Args:
        model (`torch.nn.Module`):
            The model to scale.
        weight (`float`):
            The weight to be given to the LoRA layers.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    tuner_layers = (layer for layer in model.modules() if isinstance(layer, BaseTunerLayer))
    for layer in tuner_layers:
        layer.scale_layer(weight)


118
def unscale_lora_layers(model, weight: Optional[float] = None):
    """
    Removes the previously passed weight given to the LoRA layers of the model.

    Args:
        model (`torch.nn.Module`):
            The model to unscale; it is modified in place.
        weight (`float`, *optional*):
            The weight previously given to the LoRA layers (typically the value passed to `scale_lora_layers`).
            - `None` (default): nothing is done.
            - `0.0`: the scale of every active adapter is re-set to its original value via
              `set_scale(adapter_name, 1.0)`.
            - any other value: forwarded to each tuner layer's `unscale_layer`.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    if weight is None:
        # No weight was given, so there is nothing to undo.
        return

    for module in model.modules():
        if not isinstance(module, BaseTunerLayer):
            continue
        if weight != 0:
            module.unscale_layer(weight)
        else:
            # if weight == 0 unscale should re-set the scale to the original value.
            for adapter_name in module.active_adapters:
                module.set_scale(adapter_name, 1.0)
140
141


142
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
    """
    Build the keyword arguments for a PEFT `LoraConfig` from a LoRA state dict.

    The rank (and alpha) shared by the largest number of modules becomes the config-wide `r` (and `lora_alpha`);
    modules that deviate from it are listed in `rank_pattern` (and `alpha_pattern`).

    Args:
        rank_dict (`dict`):
            Maps each `lora_B` weight key to its rank. Must be non-empty.
        network_alpha_dict (`dict`, *optional*):
            Maps alpha keys to their alpha value. When `None` or empty, `lora_alpha` equals the default rank.
        peft_state_dict (`dict`):
            The LoRA state dict in PEFT format; only its keys are inspected.
        is_unet (`bool`, defaults to `True`):
            Selects how module names are recovered from `network_alpha_dict` keys.

    Returns:
        `dict`: With keys `r`, `lora_alpha`, `rank_pattern`, `alpha_pattern`, `target_modules` (sorted list of
        module names) and `use_dora`.
    """
    rank_pattern = {}
    alpha_pattern = {}
    r = lora_alpha = list(rank_dict.values())[0]

    if len(set(rank_dict.values())) > 1:
        # get the rank occurring the most number of times
        r = collections.Counter(rank_dict.values()).most_common(1)[0][0]

        # for modules with rank different from the most occurring rank, add it to the `rank_pattern`
        rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_dict.items() if v != r}

    if network_alpha_dict:
        if len(set(network_alpha_dict.values())) > 1:
            # get the alpha occurring the most number of times
            lora_alpha = collections.Counter(network_alpha_dict.values()).most_common(1)[0][0]

            # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
            outliers = {k: v for k, v in network_alpha_dict.items() if v != lora_alpha}
            if is_unet:
                alpha_pattern = {k.split(".lora_A.")[0].replace(".alpha", ""): v for k, v in outliers.items()}
            else:
                alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in outliers.items()}
        else:
            # every module shares the same alpha
            lora_alpha = next(iter(network_alpha_dict.values()))

    # layer names without the Diffusers specific suffix; sorted so the resulting config is deterministic
    target_modules = sorted({name.split(".lora")[0] for name in peft_state_dict})
    use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)

    lora_config_kwargs = {
        "r": r,
        "lora_alpha": lora_alpha,
        "rank_pattern": rank_pattern,
        "alpha_pattern": alpha_pattern,
        "target_modules": target_modules,
        "use_dora": use_dora,
    }
    return lora_config_kwargs


def get_adapter_name(model):
    """
    Propose a name for a new adapter on `model`.

    The name is `default_{n}`, where `n` is `len(r)` of the first PEFT tuner layer found in `model` (presumably the
    number of adapters already loaded), or `0` when `model` contains no tuner layer.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    first_tuner = next((layer for layer in model.modules() if isinstance(layer, BaseTunerLayer)), None)
    suffix = 0 if first_tuner is None else len(first_tuner.r)
    return f"default_{suffix}"


def set_adapter_layers(model, enabled=True):
    """
    Enable or disable every PEFT tuner layer of `model`.

    Args:
        model (`torch.nn.Module`):
            The model whose tuner layers are toggled in place.
        enabled (`bool`, defaults to `True`):
            `True` to enable the adapters, `False` to disable them.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            # The recent version of PEFT needs to call `enable_adapters`; older versions expose a plain
            # `disable_adapters` attribute instead.
            if hasattr(module, "enable_adapters"):
                module.enable_adapters(enabled=enabled)
            else:
                module.disable_adapters = not enabled
206
207


208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
def delete_adapter_layers(model, adapter_name):
    """
    Delete the adapter `adapter_name` from every PEFT tuner layer of `model`.

    When the model was loaded through the transformers PEFT integration (`_hf_peft_config_loaded` is truthy and a
    `peft_config` attribute exists), the adapter's entry is also popped from `model.peft_config`; once no adapter
    remains, `peft_config` is deleted and `_hf_peft_config_loaded` is set to `False`.

    Args:
        model (`torch.nn.Module`):
            The model to modify in place.
        adapter_name (`str`):
            Name of the adapter to delete.

    Raises:
        ValueError: If a tuner layer has no `delete_adapter` method (the installed PEFT is too old).
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            if hasattr(module, "delete_adapter"):
                module.delete_adapter(adapter_name)
            else:
                raise ValueError(
                    "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
                )

    # For transformers integration - we need to pop the adapter from the config
    if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"):
        model.peft_config.pop(adapter_name, None)
        # In case all adapters are deleted, we need to delete the config
        # and make sure to set the flag to False
        if len(model.peft_config) == 0:
            del model.peft_config
            model._hf_peft_config_loaded = False


230
231
232
def set_weights_and_activate_adapters(model, adapter_names, weights):
    """
    Give each adapter in `adapter_names` its own scale, then make all of them active at once.

    Args:
        model (`torch.nn.Module`):
            The model holding PEFT tuner layers; modified in place.
        adapter_names:
            Names of the adapters to activate.
        weights:
            One entry per adapter, paired with `adapter_names` via `zip` (surplus entries on either side are
            ignored). Each entry is either a single number applied to every tuner layer, or a dict mapping a
            layer-name substring to the weight used for tuner layers whose name contains it.

    Raises:
        RuntimeError: If a dict entry in `weights` has no key contained in the name of some tuner layer.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    def get_module_weight(weight_for_adapter, module_name):
        # Resolve the scale for one module from a per-adapter weight spec.
        if not isinstance(weight_for_adapter, dict):
            # If weight_for_adapter is a single number, always return it.
            return weight_for_adapter

        for layer_name, weight_ in weight_for_adapter.items():
            if layer_name in module_name:
                return weight_
        raise RuntimeError(f"No LoRA weight found for module {module_name}.")

    # iterate over each adapter, make it active and set the corresponding scaling weight
    for adapter_name, weight in zip(adapter_names, weights):
        for module_name, module in model.named_modules():
            if isinstance(module, BaseTunerLayer):
                # For backward compatibility with previous PEFT versions
                if hasattr(module, "set_adapter"):
                    module.set_adapter(adapter_name)
                else:
                    module.active_adapter = adapter_name
                module.set_scale(adapter_name, get_module_weight(weight, module_name))

    # set multiple active adapters
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            # For backward compatibility with previous PEFT versions
            if hasattr(module, "set_adapter"):
                module.set_adapter(adapter_names)
            else:
                module.active_adapter = adapter_names
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281


def check_peft_version(min_version: str) -> None:
    r"""
    Checks that the installed version of PEFT is strictly greater than `min_version`.

    Args:
        min_version (`str`):
            The version of PEFT to check against.

    Raises:
        ValueError: If PEFT is not installed, or if its installed version is not greater than `min_version`.
    """
    # `import importlib` alone does not guarantee that the `importlib.metadata` submodule is loaded.
    import importlib.metadata

    if not is_peft_available():
        raise ValueError("PEFT is not installed. Please install it with `pip install peft`")

    installed_version = version.parse(importlib.metadata.version("peft"))
    is_peft_version_compatible = installed_version > version.parse(min_version)

    if not is_peft_version_compatible:
        raise ValueError(
            f"The version of PEFT you are using is not compatible, please use a version that is greater"
            f" than {min_version}"
        )