from typing import TYPE_CHECKING

from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate
from ..utils.import_utils import is_peft_available, is_torch_available, is_transformers_available


def text_encoder_lora_state_dict(text_encoder):
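    """Deprecated helper kept for backward compatibility.

    Collects the LoRA linear-layer weights from the q/k/v/out projections of every attention
    module of a CLIP `text_encoder` and returns them as a flat state dict keyed by module path.
    """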
    deprecate(
        "text_encoder_lora_state_dict in `models`",
        "0.27.0",
        "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
    )
    state_dict = {}

    for name, module in text_encoder_attn_modules(text_encoder):
        for k, v in module.q_proj.lora_linear_layer.state_dict().items():
            state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v

        for k, v in module.k_proj.lora_linear_layer.state_dict().items():
            state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v

        for k, v in module.v_proj.lora_linear_layer.state_dict().items():
            state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v

        for k, v in module.out_proj.lora_linear_layer.state_dict().items():
            state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v

    return state_dict


if is_transformers_available():
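    # Only defined when `transformers` is installed, since it relies on the CLIP text encoder classes.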

    def text_encoder_attn_modules(text_encoder):
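        """Deprecated helper that returns `(name, module)` pairs for each self-attention block of a CLIP text encoder."""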
        deprecate(
            "text_encoder_attn_modules in `models`",
            "0.27.0",
            "`text_encoder_attn_modules` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
        )
        from transformers import CLIPTextModel, CLIPTextModelWithProjection

        attn_modules = []

        if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
            for i, layer in enumerate(text_encoder.text_model.encoder.layers):
                name = f"text_model.encoder.layers.{i}.self_attn"
                mod = layer.self_attn
                attn_modules.append((name, mod))
        else:
            raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")

        return attn_modules


_import_structure = {}
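# Mapping of submodule name -> public names it re-exports. _LazyModule (set up at the bottom
# of this file) uses it to import a submodule only when one of its symbols is first accessed.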

if is_torch_available():
    _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
    _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
    _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
    _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
    _import_structure["utils"] = ["AttnProcsLayers"]
    if is_transformers_available():
        _import_structure["single_file"] = ["FromSingleFileMixin"]
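        # Per-pipeline LoRA loader mixins; each one adds LoRA loading/saving helpers
        # (e.g. `load_lora_weights`) to the corresponding pipeline class.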
        _import_structure["lora_pipeline"] = [
            "AmusedLoraLoaderMixin",
            "StableDiffusionLoraLoaderMixin",
            "SD3LoraLoaderMixin",
            "AuraFlowLoraLoaderMixin",
            "StableDiffusionXLLoraLoaderMixin",
            "LTXVideoLoraLoaderMixin",
            "LoraLoaderMixin",
            "FluxLoraLoaderMixin",
            "CogVideoXLoraLoaderMixin",
            "CogView4LoraLoaderMixin",
            "Mochi1LoraLoaderMixin",
            "HunyuanVideoLoraLoaderMixin",
            "SanaLoraLoaderMixin",
            "Lumina2LoraLoaderMixin",
            "WanLoraLoaderMixin",
            "KandinskyLoraLoaderMixin",
            "HiDreamImageLoraLoaderMixin",
            "SkyReelsV2LoraLoaderMixin",
            "QwenImageLoraLoaderMixin",
            "ZImageLoraLoaderMixin",
            "Flux2LoraLoaderMixin",
        ]
        _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
        _import_structure["ip_adapter"] = [
            "IPAdapterMixin",
            "FluxIPAdapterMixin",
            "SD3IPAdapterMixin",
            "ModularIPAdapterMixin",
        ]

_import_structure["peft"] = ["PeftAdapterMixin"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
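    # During type checking, or when DIFFUSERS_SLOW_IMPORT is set, perform the real (eager)
    # imports so that static analyzers and users see the actual classes.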
    if is_torch_available():
        from .single_file_model import FromOriginalModelMixin
        from .transformer_flux import FluxTransformer2DLoadersMixin
        from .transformer_sd3 import SD3Transformer2DLoadersMixin
        from .unet import UNet2DConditionLoadersMixin
        from .utils import AttnProcsLayers

        if is_transformers_available():
            from .ip_adapter import (
                FluxIPAdapterMixin,
                IPAdapterMixin,
                ModularIPAdapterMixin,
                SD3IPAdapterMixin,
            )
            from .lora_pipeline import (
                AmusedLoraLoaderMixin,
                AuraFlowLoraLoaderMixin,
                CogVideoXLoraLoaderMixin,
                CogView4LoraLoaderMixin,
                Flux2LoraLoaderMixin,
                FluxLoraLoaderMixin,
                HiDreamImageLoraLoaderMixin,
                HunyuanVideoLoraLoaderMixin,
                KandinskyLoraLoaderMixin,
                LoraLoaderMixin,
                LTXVideoLoraLoaderMixin,
                Lumina2LoraLoaderMixin,
                Mochi1LoraLoaderMixin,
                QwenImageLoraLoaderMixin,
                SanaLoraLoaderMixin,
                SD3LoraLoaderMixin,
                SkyReelsV2LoraLoaderMixin,
                StableDiffusionLoraLoaderMixin,
                StableDiffusionXLLoraLoaderMixin,
                WanLoraLoaderMixin,
                ZImageLoraLoaderMixin,
            )
            from .single_file import FromSingleFileMixin
            from .textual_inversion import TextualInversionLoaderMixin

    from .peft import PeftAdapterMixin
else:
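    # At runtime, replace this module with a _LazyModule proxy that defers the imports
    # declared in _import_structure until their attributes are first accessed.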
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)