# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Dict, List, Union

from torch import nn

from ..utils import logging


if TYPE_CHECKING:
    # import here to avoid circular imports
    from ..models import UNet2DConditionModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def _translate_into_actual_layer_name(name):
    """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
    if name == "mid":
        return "mid_block.attentions.0"

    updown, block, attn = name.split(".")

    updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
    block = block.replace("block_", "")
    attn = "attentions." + attn

    return ".".join((updown, block, attn))


def _maybe_expand_lora_scales(
    unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
):
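    """Expand each adapter's entry in `weight_scales` into a granular per-layer scale dict;
    see `_maybe_expand_lora_scales_for_one_adapter` for the expansion rules."""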
    # Collect the indices of the down/up blocks that actually contain attention
    # (transformer) layers; in this UNet layout, each up block has one more
    # transformer layer than a down block.
    blocks_with_transformer = {
        "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
        "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
    }
    transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}

    expanded_weight_scales = [
        _maybe_expand_lora_scales_for_one_adapter(
            weight_for_adapter,
            blocks_with_transformer,
            transformer_per_block,
            model=unet,
            default_scale=default_scale,
        )
        for weight_for_adapter in weight_scales
    ]

    return expanded_weight_scales
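
# A minimal usage sketch (hypothetical values; assumes `unet` is a `UNet2DConditionModel`
# whose up blocks each have three transformer layers):
#
#     scales = [{"down": 0.8, "mid": 1.0, "up": {"block_0": 0.6, "block_1": [0.4, 0.5, 0.6]}}]
#     expanded = _maybe_expand_lora_scales(unet, scales)
#     # expanded[0] maps concrete layer names to scales,
#     # e.g. {"up_blocks.1.attentions.2": 0.6, "mid_block.attentions.0": 1.0, ...}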


def _maybe_expand_lora_scales_for_one_adapter(
    scales: Union[float, Dict],
    blocks_with_transformer: Dict[str, List[int]],
    transformer_per_block: Dict[str, int],
    model: nn.Module,
    default_scale: float = 1.0,
):
    """
    Expands the inputs into a more granular dictionary. See the example below for more details.

    Parameters:
        scales (`Union[float, Dict]`):
            Scales dict to expand.
        blocks_with_transformer (`Dict[str, List[int]]`):
            Dict with keys 'up' and 'down', listing the indices of the blocks that have transformer layers
        transformer_per_block (`Dict[str, int]`):
            Dict with keys 'up' and 'down', giving the number of transformer layers per block
        model (`nn.Module`):
            The UNet whose `state_dict` is used to validate the expanded layer names
        default_scale (`float`, *optional*, defaults to 1.0):
            Scale to assign to any block not specified in `scales`

    E.g. turns
    ```python
    scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
    blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
    transformer_per_block = {"down": 2, "up": 3}
    ```
    into
    ```python
    {
        "down.block_1.0": 2,
        "down.block_1.1": 2,
        "down.block_2.0": 2,
        "down.block_2.1": 2,
        "mid": 3,
        "up.block_0.0": 4,
        "up.block_0.1": 4,
        "up.block_0.2": 4,
        "up.block_1.0": 5,
        "up.block_1.1": 6,
        "up.block_1.2": 7,
    }
    ```
    """
    if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
        raise ValueError("blocks_with_transformer needs to be a dict with keys `'down' and `'up'`")

    if sorted(transformer_per_block.keys()) != ["down", "up"]:
        raise ValueError("transformer_per_block needs to be a dict with keys `'down' and `'up'`")

    if not isinstance(scales, dict):
        # don't expand if scales is a single number
        return scales

    scales = copy.deepcopy(scales)

    if "mid" not in scales:
        scales["mid"] = default_scale
    elif isinstance(scales["mid"], list):
        if len(scales["mid"]) == 1:
            scales["mid"] = scales["mid"][0]
        else:
            raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")

    for updown in ["up", "down"]:
        if updown not in scales:
            scales[updown] = default_scale

        # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
        if not isinstance(scales[updown], dict):
            scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}

        # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            # set not assigned blocks to default scale
            if block not in scales[updown]:
                scales[updown][block] = default_scale
            if not isinstance(scales[updown][block], list):
                scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
            elif len(scales[updown][block]) == 1:
                # a single-element list (e.g. one scale per masked IP-Adapter input) is broadcast to all transformer layers
                scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
            elif len(scales[updown][block]) != transformer_per_block[updown]:
                raise ValueError(
                    f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
                )

        # eg {"down": "block_1": [1, 1]}}  to {"down.block_1.0": 1, "down.block_1.1": 1}
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            for tf_idx, value in enumerate(scales[updown][block]):
                scales[f"{updown}.{block}.{tf_idx}"] = value

        del scales[updown]

    state_dict = model.state_dict()
    for layer in scales.keys():
        if not any(_translate_into_actual_layer_name(layer) in key for key in state_dict.keys()):
            raise ValueError(
                f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
            )

    return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}
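

# For context: these helpers back per-block LoRA scaling in the public API. A hedged
# sketch (hypothetical adapter name; `pipe` is a pipeline with a LoRA adapter loaded):
#
#     pipe.set_adapters(["my_adapter"], [{"down": 0.5, "mid": 1.0, "up": 0.9}])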