# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PEFT utilities: Utilities related to peft library
"""
import collections
import importlib
from typing import Optional

from packaging import version

from . import logging
from .import_utils import is_peft_available, is_peft_version, is_torch_available
from .torch_utils import empty_device_cache


logger = logging.get_logger(__name__)

# torch is optional at import time; functions below that touch `torch` are only
# reached when it is installed.
if is_torch_available():
    import torch


def recurse_remove_peft_layers(model):
    r"""
    Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.

    Two strategies are used, chosen by inspecting the first PEFT tuner layer found in `model`:

    - If tuner layers expose a `base_layer` attribute, every such layer is replaced in its parent by
      `get_base_layer()`.
    - Otherwise (backwards compatibility with PEFT <= 0.6.2), every module that is both a `LoraLayer` and a
      `torch.nn.Linear` / `torch.nn.Conv2d` is rebuilt as a plain `Linear` / `Conv2d` that reuses the original
      weight and bias parameters.

    Args:
        model (`torch.nn.Module`):
            The model to strip LoRA layers from. It is modified in place.

    Returns:
        `torch.nn.Module`: The same `model` object, with the LoRA layers removed.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    # Decide which layer layout is in use from the first tuner layer encountered.
    has_base_layer_pattern = False
    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            has_base_layer_pattern = hasattr(module, "base_layer")
            break

    if has_base_layer_pattern:
        from peft.utils import _get_submodules

        # The LoRA sub-modules themselves are skipped; only the layers wrapping them are replaced.
        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
        for key in key_list:
            try:
                parent, target, target_name = _get_submodules(model, key)
            except AttributeError:
                # Presumably the key belongs to a module already detached by an earlier replacement in this loop.
                continue
            if hasattr(target, "base_layer"):
                setattr(parent, target_name, target.get_base_layer())
    else:
        # This is for backwards compatibility with PEFT <= 0.6.2.
        # TODO can be removed once that PEFT version is no longer supported.
        from peft.tuners.lora import LoraLayer

        for name, module in model.named_children():
            if len(list(module.children())) > 0:
                ## compound module, go inside it
                recurse_remove_peft_layers(module)

            module_replaced = False

            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
                new_module = torch.nn.Linear(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                ).to(module.weight.device)
                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias

                module_replaced = True
            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
                # Mirror the Linear branch: only allocate a bias when the original layer has one. Relying on the
                # `Conv2d` default (`bias=True`) would leave a randomly initialised bias on bias-free convolutions.
                # `padding_mode` is carried over too so non-"zeros" padding is preserved.
                new_module = torch.nn.Conv2d(
                    module.in_channels,
                    module.out_channels,
                    module.kernel_size,
                    module.stride,
                    module.padding,
                    module.dilation,
                    module.groups,
                    bias=module.bias is not None,
                    padding_mode=module.padding_mode,
                ).to(module.weight.device)

                new_module.weight = module.weight
                if module.bias is not None:
                    new_module.bias = module.bias

                module_replaced = True

            if module_replaced:
                setattr(model, name, new_module)
                del module

                empty_device_cache()
    return model


def scale_lora_layers(model, weight):
    """
    Adjust the weightage given to the LoRA layers of the model.

    `scale_layer(weight)` is called on every PEFT tuner layer (`BaseTunerLayer`) found in `model`. A `weight` of
    `1.0` is treated as a no-op and returns immediately.

    Args:
        model (`torch.nn.Module`):
            The model to scale.
        weight (`float`):
            The weight to be given to the LoRA layers.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    # Scaling by 1.0 would not change anything, so skip walking the model.
    if weight == 1.0:
        return

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            module.scale_layer(weight)


def unscale_lora_layers(model, weight: Optional[float] = None):
    """
    Removes the previously passed weight given to the LoRA layers of the model.

    Args:
        model (`torch.nn.Module`):
            The model to scale.
        weight (`float`, *optional*):
            The weight that was previously applied to the LoRA layers (presumably via `scale_lora_layers`).
            - `None` or `1.0`: nothing is done.
            - `0.0`: the scale of every active adapter is re-initialized with `set_scale(adapter_name, 1.0)`.
            - any other value: `unscale_layer(weight)` is called on every PEFT tuner layer.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    # No weight, or the identity weight, means there is nothing to undo.
    if weight is None or weight == 1.0:
        return

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            if weight != 0:
                module.unscale_layer(weight)
            else:
                # A zero weight is not undone via `unscale_layer`; the scale is reset instead.
                for adapter_name in module.active_adapters:
                    # if weight == 0 unscale should re-set the scale to the original value.
                    module.set_scale(adapter_name, 1.0)


def get_peft_kwargs(
    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
    """
    Build the keyword arguments for `peft.LoraConfig` from a LoRA state dict.

    The most common rank (and alpha) becomes the global `r` (and `lora_alpha`). Modules whose value differs are
    recorded in `rank_pattern` / `alpha_pattern`.

    Args:
        rank_dict (`dict`):
            Maps parameter names to LoRA ranks. Keys are shortened to the part before `.lora_B.` when recorded in
            `rank_pattern`. Must not be empty, because its first value is read unconditionally.
        network_alpha_dict (`dict`, *optional*):
            Maps parameter names to alpha values. May be `None` or empty. When `is_unet` is true, keys are shortened
            to the part before `.lora_A.` with `.alpha` removed. Otherwise they are shortened to the part before
            `.down.`, minus its last path component.
        peft_state_dict (`dict`):
            The LoRA state dict. Its keys determine:
            - `target_modules`: the prefix before `.lora`.
            - `use_dora`: whether any key contains `lora_magnitude_vector`.
            - `lora_bias`: whether any `lora_B` key ends with `.bias`.
        is_unet (`bool`, defaults to `True`):
            Selects how `network_alpha_dict` keys are shortened.
        model_state_dict (`dict`, *optional*):
            Used together with `adapter_name` to derive `exclude_modules`.
        adapter_name (`str`, *optional*):
            Name of the adapter being loaded; see `model_state_dict`.

    Returns:
        `dict`: Keyword arguments for `LoraConfig`.
    """
    rank_pattern = {}
    alpha_pattern = {}
    # NOTE(review): when no network alphas are given, `lora_alpha` stays at the *first* rank value even if `r` is
    # later replaced by the most common rank below -- confirm this is intended.
    r = lora_alpha = list(rank_dict.values())[0]

    if len(set(rank_dict.values())) > 1:
        # get the rank occurring the most number of times
        r = collections.Counter(rank_dict.values()).most_common()[0][0]

        # for modules with rank different from the most occurring rank, add it to the `rank_pattern`
        rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_dict.items() if v != r}

    if network_alpha_dict:
        if len(set(network_alpha_dict.values())) > 1:
            # get the alpha occurring the most number of times
            lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]

            # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
            differing_alphas = {k: v for k, v in network_alpha_dict.items() if v != lora_alpha}
            if is_unet:
                alpha_pattern = {k.split(".lora_A.")[0].replace(".alpha", ""): v for k, v in differing_alphas.items()}
            else:
                alpha_pattern = {
                    ".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in differing_alphas.items()
                }
        else:
            # All alphas are equal; take any one of them.
            lora_alpha = next(iter(network_alpha_dict.values()))

    # Sorted so the resulting config does not depend on set iteration order.
    target_modules = sorted({name.split(".lora")[0] for name in peft_state_dict})
    use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
    # for now we know that the "bias" keys are only associated with `lora_B`.
    lora_bias = any("lora_B" in k and k.endswith(".bias") for k in peft_state_dict)

    lora_config_kwargs = {
        "r": r,
        "lora_alpha": lora_alpha,
        "rank_pattern": rank_pattern,
        "alpha_pattern": alpha_pattern,
        "target_modules": target_modules,
        "use_dora": use_dora,
        "lora_bias": lora_bias,
    }

    # Example: try load FusionX LoRA into Wan VACE
    exclude_modules = _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name)
    if exclude_modules:
        if not is_peft_version(">=", "0.14.0"):
            msg = """
It seems like there are certain modules that need to be excluded when initializing `LoraConfig`. Your current `peft`
version doesn't support passing an `exclude_modules` to `LoraConfig`. Please update it by running `pip install -U
peft`. For most cases, this can be completely ignored. But if it seems unexpected, please file an issue -
https://github.com/huggingface/diffusers/issues/new
            """
            logger.debug(msg)
        else:
            lora_config_kwargs.update({"exclude_modules": exclude_modules})

    return lora_config_kwargs


def get_adapter_name(model):
    """
    Return a default adapter name of the form `default_{n}` for `model`.

    `n` is `len(...)` of the `r` attribute of the first PEFT tuner layer found in `model`. The name is `default_0`
    when `model` contains no tuner layer.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    first_tuner = next((m for m in model.modules() if isinstance(m, BaseTunerLayer)), None)
    # `r` is presumably a per-adapter mapping, so its length counts the adapters already loaded -- verify.
    n_existing = 0 if first_tuner is None else len(first_tuner.r)
    return f"default_{n_existing}"


def set_adapter_layers(model, enabled=True):
    """
    Enable or disable the adapters on every PEFT tuner layer (`BaseTunerLayer`) in `model`.

    Args:
        model (`torch.nn.Module`):
            The model whose tuner layers are toggled.
        enabled (`bool`, defaults to `True`):
            Whether the adapters should be enabled.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            # The recent version of PEFT needs to call `enable_adapters` instead
            if hasattr(module, "enable_adapters"):
                module.enable_adapters(enabled=enabled)
            else:
                # Fallback for tuner layers without `enable_adapters`: set the `disable_adapters` flag directly.
                module.disable_adapters = not enabled


def delete_adapter_layers(model, adapter_name):
    """
    Delete `adapter_name` from every PEFT tuner layer in `model` and from the model's PEFT config, if any.

    Args:
        model (`torch.nn.Module`):
            The model to remove the adapter from.
        adapter_name (`str`):
            Name of the adapter to delete.

    Raises:
        ValueError: If a tuner layer has no `delete_adapter` method (PEFT too old).
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    for module in model.modules():
        if isinstance(module, BaseTunerLayer):
            if hasattr(module, "delete_adapter"):
                module.delete_adapter(adapter_name)
            else:
                raise ValueError(
                    "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
                )

    # For transformers integration - we need to pop the adapter from the config
    if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"):
        model.peft_config.pop(adapter_name, None)
        # In case all adapters are deleted, we need to delete the config
        # and make sure to set the flag to False
        if len(model.peft_config) == 0:
            del model.peft_config
            model._hf_peft_config_loaded = False


def set_weights_and_activate_adapters(model, adapter_names, weights):
    """
    Activate `adapter_names` on every PEFT tuner layer of `model` and set each adapter's scale.

    Args:
        model (`torch.nn.Module`):
            The model whose tuner layers are updated.
        adapter_names (`list[str]`):
            The adapters to activate.
        weights (`list`):
            One entry per adapter, in the same order as `adapter_names`; `zip` ignores any surplus entries of the
            longer list. Each entry is either:
            - a single number, applied to every layer; or
            - a dict mapping module-name fragments to weights. Unmatched modules get `1.0`.
    """
    from peft.tuners.tuners_utils import BaseTunerLayer

    def get_module_weight(weight_for_adapter, module_name):
        if not isinstance(weight_for_adapter, dict):
            # If weight_for_adapter is a single number, always return it.
            return weight_for_adapter

        for layer_name, weight_ in weight_for_adapter.items():
            if layer_name in module_name:
                return weight_

        parts = module_name.split(".")
        # The block-level lookup below needs at least four path components. Shorter names (e.g. top-level modules)
        # used to raise an IndexError; give them the same default the lookup uses for unknown blocks.
        if len(parts) < 4:
            return 1.0
        # e.g. key = "down_blocks.1.attentions.0"
        key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
        block_weight = weight_for_adapter.get(key, 1.0)

        return block_weight

    for module_name, module in model.named_modules():
        if isinstance(module, BaseTunerLayer):
            # For backward compatibility with previous PEFT versions, set multiple active adapters
            if hasattr(module, "set_adapter"):
                module.set_adapter(adapter_names)
            else:
                module.active_adapter = adapter_names

            # Set the scaling weight for each adapter for this module
            for adapter_name, weight in zip(adapter_names, weights):
                module.set_scale(adapter_name, get_module_weight(weight, module_name))


def check_peft_version(min_version: str) -> None:
    r"""
    Checks if the version of PEFT is compatible.

    The installed PEFT version must be strictly greater than `min_version`.

    Args:
        min_version (`str`):
            The version of PEFT to check against.

    Raises:
        ValueError: If PEFT is not installed, or if the installed version is not greater than `min_version`.
    """
    # `import importlib` alone does not guarantee that the `importlib.metadata` submodule is loaded, so import it
    # explicitly instead of relying on some other module having done so.
    import importlib.metadata

    if not is_peft_available():
        raise ValueError("PEFT is not installed. Please install it with `pip install peft`")

    is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) > version.parse(min_version)

    if not is_peft_version_compatible:
        raise ValueError(
            f"The version of PEFT you are using is not compatible, please use a version that is greater"
            f" than {min_version}"
        )


def _create_lora_config(
    state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
):
    """
    Build a `peft.LoraConfig` for loading a LoRA state dict.

    Args:
        state_dict (`dict`):
            The LoRA state dict; used to infer the config when `metadata` is `None`.
        network_alphas (`dict`, *optional*):
            Forwarded to `get_peft_kwargs` as `network_alpha_dict`.
        metadata (`dict`, *optional*):
            Ready-made `LoraConfig` keyword arguments. When given, they are used as-is and nothing is inferred from
            `state_dict`.
        rank_pattern_dict (`dict`):
            Forwarded to `get_peft_kwargs` as `rank_dict`.
        is_unet, model_state_dict, adapter_name:
            Forwarded unchanged to `get_peft_kwargs`.

    Raises:
        ValueError: If the installed PEFT version is too old for ambiguous keys, DoRA, or `lora_bias`.
        TypeError: If `LoraConfig` cannot be instantiated from the collected keyword arguments.
    """
    from peft import LoraConfig

    if metadata is not None:
        lora_config_kwargs = metadata
    else:
        lora_config_kwargs = get_peft_kwargs(
            rank_pattern_dict,
            network_alpha_dict=network_alphas,
            peft_state_dict=state_dict,
            is_unet=is_unet,
            model_state_dict=model_state_dict,
            adapter_name=adapter_name,
        )

    _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)

    # Version checks for DoRA and lora_bias
    if lora_config_kwargs.get("use_dora") and is_peft_version("<", "0.9.0"):
        raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")

    if lora_config_kwargs.get("lora_bias") and is_peft_version("<=", "0.13.2"):
        raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.")

    try:
        return LoraConfig(**lora_config_kwargs)
    except TypeError as e:
        raise TypeError("`LoraConfig` class could not be instantiated.") from e


def _maybe_raise_error_for_ambiguous_keys(config):
    rank_pattern = config["rank_pattern"].copy()
    target_modules = config["target_modules"]

    for key in list(rank_pattern.keys()):
        # try to detect ambiguity
        # `target_modules` can also be a str, in which case this loop would loop
        # over the chars of the str. The technically correct way to match LoRA keys
        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
        # But this cuts it for now.
        exact_matches = [mod for mod in target_modules if mod == key]
        substring_matches = [mod for mod in target_modules if key in mod and mod != key]

        if exact_matches and substring_matches:
            if is_peft_version("<", "0.14.1"):
                raise ValueError(
                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
                )


def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
    warn_msg = ""
    if incompatible_keys is not None:
        # Check only for unexpected keys.
        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
        if unexpected_keys:
            lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
            if lora_unexpected_keys:
                warn_msg = (
                    f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
                    f" {', '.join(lora_unexpected_keys)}. "
                )

        # Filter missing keys specific to the current adapter.
        missing_keys = getattr(incompatible_keys, "missing_keys", None)
        if missing_keys:
            lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
            if lora_missing_keys:
                warn_msg += (
                    f"Loading adapter weights from state_dict led to missing keys in the model:"
                    f" {', '.join(lora_missing_keys)}."
                )

    if warn_msg:
        logger.warning(warn_msg)
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414


def _derive_exclude_modules(model_state_dict, peft_state_dict, adapter_name=None):
    """
    Derives the modules to exclude while initializing `LoraConfig` through `exclude_modules`. It works by comparing the
    `model_state_dict` and `peft_state_dict` and adds a module from `model_state_dict` to the exclusion set if it
    doesn't exist in `peft_state_dict`.
    """
    if model_state_dict is None:
        return
    all_modules = set()
    string_to_replace = f"{adapter_name}." if adapter_name else ""

    for name in model_state_dict.keys():
        if string_to_replace:
            name = name.replace(string_to_replace, "")
        if "." in name:
            module_name = name.rsplit(".", 1)[0]
            all_modules.add(module_name)

    target_modules_set = {name.split(".lora")[0] for name in peft_state_dict.keys()}
    exclude_modules = list(all_modules - target_modules_set)

    return exclude_modules