from typing import TYPE_CHECKING

from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, deprecate
from ..utils.import_utils import is_peft_available, is_torch_available, is_transformers_available


def text_encoder_lora_state_dict(text_encoder):
    """Collect the LoRA weights attached to a text encoder's attention projections.

    Deprecated: calls ``deprecate`` on every invocation; retrieve the weights with
    PEFT's ``get_peft_model`` instead.

    Args:
        text_encoder: A text encoder accepted by ``text_encoder_attn_modules``
            (in this file: ``CLIPTextModel`` or ``CLIPTextModelWithProjection``).

    Returns:
        dict: Maps ``"{module_name}.{proj}.lora_linear_layer.{param}"`` to the value
        from that projection's ``lora_linear_layer.state_dict()``, for ``proj`` in
        ``q_proj``, ``k_proj``, ``v_proj``, ``out_proj`` (in that order).
    """
    # NOTE(review): removal version is 0.27.0 — confirm whether `deprecate` raises once the
    # installed diffusers version reaches it; if so, this helper is due for deletion.
    deprecate(
        "text_encoder_lora_state_dict in `models`",
        "0.27.0",
        "`text_encoder_lora_state_dict` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
    )
    state_dict = {}

    # `text_encoder_attn_modules` is only defined when transformers is available.
    for name, module in text_encoder_attn_modules(text_encoder):
        # Every attention block exposes the same four projections, each with a LoRA layer.
        for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            lora_layer = getattr(module, proj_name).lora_linear_layer
            for param_name, value in lora_layer.state_dict().items():
                state_dict[f"{name}.{proj_name}.lora_linear_layer.{param_name}"] = value

    return state_dict


if is_transformers_available():

    def text_encoder_attn_modules(text_encoder):
        """Return ``(name, module)`` pairs for each self-attention block of a CLIP text encoder.

        Deprecated: calls ``deprecate`` on every invocation.

        Args:
            text_encoder: A ``transformers`` ``CLIPTextModel`` or
                ``CLIPTextModelWithProjection``.

        Returns:
            list[tuple[str, object]]: One entry per layer in
            ``text_encoder.text_model.encoder.layers``, named
            ``"text_model.encoder.layers.{i}.self_attn"`` and paired with ``layer.self_attn``.

        Raises:
            ValueError: If ``text_encoder`` is not one of the supported CLIP classes.
        """
        # NOTE(review): removal version is 0.27.0 — confirm whether `deprecate` raises once the
        # installed diffusers version reaches it.
        deprecate(
            "text_encoder_attn_modules in `models`",
            "0.27.0",
            "`text_encoder_attn_modules` is deprecated and will be removed in 0.27.0. Make sure to retrieve the weights using `get_peft_model`. See https://huggingface.co/docs/peft/v0.6.2/en/quicktour#peftmodel for more information.",
        )
        # Deferred import; presumably keeps module import cheap — TODO confirm intent.
        from transformers import CLIPTextModel, CLIPTextModelWithProjection

        if not isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
            raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")

        return [
            (f"text_model.encoder.layers.{i}.self_attn", layer.self_attn)
            for i, layer in enumerate(text_encoder.text_model.encoder.layers)
        ]


# Maps each submodule name to the public names it exports. This dict is passed to
# `_LazyModule` at the end of this file; keep it in sync with the eager imports there.
_import_structure = {}

if is_torch_available():
    _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
    _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
    _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
    _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
    _import_structure["utils"] = ["AttnProcsLayers"]
    # These loaders are only registered when both torch and transformers are installed.
    if is_transformers_available():
        _import_structure["single_file"] = ["FromSingleFileMixin"]
        _import_structure["lora_pipeline"] = [
            "AmusedLoraLoaderMixin",
            "StableDiffusionLoraLoaderMixin",
            "SD3LoraLoaderMixin",
            "StableDiffusionXLLoraLoaderMixin",
            "LTXVideoLoraLoaderMixin",
            "LoraLoaderMixin",
            "FluxLoraLoaderMixin",
            "CogVideoXLoraLoaderMixin",
            "Mochi1LoraLoaderMixin",
            "HunyuanVideoLoraLoaderMixin",
            "SanaLoraLoaderMixin",
            "Lumina2LoraLoaderMixin",
            "WanLoraLoaderMixin",
        ]
        _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
        _import_structure["ip_adapter"] = [
            "IPAdapterMixin",
            "FluxIPAdapterMixin",
            "SD3IPAdapterMixin",
        ]

# Registered unconditionally, outside the torch/transformers availability checks.
_import_structure["peft"] = ["PeftAdapterMixin"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import every public name directly, for static type checkers or when
    # DIFFUSERS_SLOW_IMPORT is set. The guards mirror those used to build `_import_structure`.
    if is_torch_available():
        from .single_file_model import FromOriginalModelMixin
        from .transformer_flux import FluxTransformer2DLoadersMixin
        from .transformer_sd3 import SD3Transformer2DLoadersMixin
        from .unet import UNet2DConditionLoadersMixin
        from .utils import AttnProcsLayers

        if is_transformers_available():
            from .ip_adapter import (
                FluxIPAdapterMixin,
                IPAdapterMixin,
                SD3IPAdapterMixin,
            )
            from .lora_pipeline import (
                AmusedLoraLoaderMixin,
                CogVideoXLoraLoaderMixin,
                FluxLoraLoaderMixin,
                HunyuanVideoLoraLoaderMixin,
                LoraLoaderMixin,
                LTXVideoLoraLoaderMixin,
                Lumina2LoraLoaderMixin,
                Mochi1LoraLoaderMixin,
                SanaLoraLoaderMixin,
                SD3LoraLoaderMixin,
                StableDiffusionLoraLoaderMixin,
                StableDiffusionXLLoraLoaderMixin,
                WanLoraLoaderMixin,
            )
            from .single_file import FromSingleFileMixin
            from .textual_inversion import TextualInversionLoaderMixin

    from .peft import PeftAdapterMixin
else:
    import sys

    # Lazy path: replace this module in `sys.modules` with a `_LazyModule` built from
    # `_import_structure`; presumably submodules are imported on first attribute access.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)