# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from typing import TYPE_CHECKING, Dict, List, Union

from ..utils import logging


if TYPE_CHECKING:
    # import here to avoid circular imports
    from ..models import UNet2DConditionModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def _translate_into_actual_layer_name(name):
    """Translate user-friendly name (e.g. 'mid') into actual layer name (e.g. 'mid_block.attentions.0')"""
    if name == "mid":
        return "mid_block.attentions.0"

    updown, block, attn = name.split(".")

    updown = updown.replace("down", "down_blocks").replace("up", "up_blocks")
    block = block.replace("block_", "")
    attn = "attentions." + attn

    return ".".join((updown, block, attn))


def _maybe_expand_lora_scales(
    unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
):
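    """
    Expand each adapter's entry in `weight_scales` into a per-layer scale dict for `unet`. An entry may be a
    single number (applied to every attention layer) or a nested dict; see
    `_maybe_expand_lora_scales_for_one_adapter` for the accepted format.
    """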
    blocks_with_transformer = {
        "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
        "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
    }
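    # in `UNet2DConditionModel`, up blocks carry one more resnet (and hence, in cross-attention blocks, one
    # more transformer layer) than down blocks, because of the skip connections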
    transformer_per_block = {"down": unet.config.layers_per_block, "up": unet.config.layers_per_block + 1}

    expanded_weight_scales = [
        _maybe_expand_lora_scales_for_one_adapter(
            weight_for_adapter,
            blocks_with_transformer,
            transformer_per_block,
            unet.state_dict(),
            default_scale=default_scale,
        )
        for weight_for_adapter in weight_scales
    ]

    return expanded_weight_scales
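
# Minimal usage sketch (hypothetical; assumes `unet` is a loaded `UNet2DConditionModel`, e.g. SDXL's, whose
# down blocks 1-2 and up blocks 0-1 carry attentions):
#
#   scales = [{"down": 0.8, "mid": 0.5, "up": {"block_0": 0.6, "block_1": [0.4, 0.5, 0.6]}}]
#   expanded = _maybe_expand_lora_scales(unet, scales)
#   # expanded[0] now maps actual layer names to scales,
#   # e.g. {"down_blocks.1.attentions.0": 0.8, ..., "mid_block.attentions.0": 0.5, ...}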


def _maybe_expand_lora_scales_for_one_adapter(
    scales: Union[float, Dict],
    blocks_with_transformer: Dict[str, List[int]],
    transformer_per_block: Dict[str, int],
    state_dict: Dict,
    default_scale: float = 1.0,
):
    """
    Expands the inputs into a more granular dictionary. See the example below for more details.

    Parameters:
        scales (`Union[float, Dict]`):
            Scales dict to expand.
        blocks_with_transformer (`Dict[str, List[int]]`):
            Dict with keys 'up' and 'down', listing the indices of the blocks that have transformer layers
        transformer_per_block (`Dict[str, int]`):
            Dict with keys 'up' and 'down', giving the number of transformer layers per block
        state_dict (`Dict`):
            State dict of the UNet, used to validate that every expanded layer name actually exists
        default_scale (`float`, defaults to `1.0`):
            Scale to use for any block or layer not mentioned in `scales`

    E.g. turns
    ```python
    scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
    blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
    transformer_per_block = {"down": 2, "up": 3}
    ```
    into
    ```python
    {
        "down.block_1.0": 2,
        "down.block_1.1": 2,
        "down.block_2.0": 2,
        "down.block_2.1": 2,
        "mid": 3,
        "up.block_0.0": 4,
        "up.block_0.1": 4,
        "up.block_0.2": 4,
        "up.block_1.0": 5,
        "up.block_1.1": 6,
        "up.block_1.2": 7,
    }
    ```
    (Before being returned, these keys are additionally translated into the actual layer names, e.g.
    `"down.block_1.0"` becomes `"down_blocks.1.attentions.0"`.)
    """
    if sorted(blocks_with_transformer.keys()) != ["down", "up"]:
        raise ValueError("blocks_with_transformer needs to be a dict with keys `'down'` and `'up'`")

    if sorted(transformer_per_block.keys()) != ["down", "up"]:
        raise ValueError("transformer_per_block needs to be a dict with keys `'down'` and `'up'`")

    if not isinstance(scales, dict):
        # don't expand if scales is a single number
        return scales

    scales = copy.deepcopy(scales)

    if "mid" not in scales:
        scales["mid"] = default_scale
    elif isinstance(scales["mid"], list):
        if len(scales["mid"]) == 1:
            scales["mid"] = scales["mid"][0]
        else:
            raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")

    for updown in ["up", "down"]:
        if updown not in scales:
            scales[updown] = default_scale

        # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
        if not isinstance(scales[updown], dict):
            scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}

        # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            # set unassigned blocks to the default scale
            if block not in scales[updown]:
                scales[updown][block] = default_scale
            if not isinstance(scales[updown][block], list):
                scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
            elif len(scales[updown][block]) == 1:
                # a single-element list is broadcast to every transformer layer in the block
                scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
            elif len(scales[updown][block]) != transformer_per_block[updown]:
                raise ValueError(
                    f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
                )

        # eg {"down": "block_1": [1, 1]}}  to {"down.block_1.0": 1, "down.block_1.1": 1}
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
            for tf_idx, value in enumerate(scales[updown][block]):
                scales[f"{updown}.{block}.{tf_idx}"] = value

        del scales[updown]

    for layer in scales.keys():
        if not any(_translate_into_actual_layer_name(layer) in module for module in state_dict.keys()):
            raise ValueError(
                f"Can't set lora scale for layer {layer}. It either doesn't exist in this unet or it has no attentions."
            )

    return {_translate_into_actual_layer_name(name): weight for name, weight in scales.items()}
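
# Worked sketch of the docstring example above (hypothetical values; `state_dict` must contain the translated
# layer names, e.g. "down_blocks.1.attentions.0", for the final validation loop to pass):
#
#   scales = {"down": 2, "mid": 3, "up": {"block_0": 4, "block_1": [5, 6, 7]}}
#   blocks_with_transformer = {"down": [1, 2], "up": [0, 1]}
#   transformer_per_block = {"down": 2, "up": 3}
#   _maybe_expand_lora_scales_for_one_adapter(scales, blocks_with_transformer, transformer_per_block, state_dict)
#   # -> {"down_blocks.1.attentions.0": 2, ..., "mid_block.attentions.0": 3, ..., "up_blocks.1.attentions.2": 7}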