# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from typing import Dict

from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.model_loading_utils import load_model_dict_into_meta
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils.torch_utils import empty_device_cache


logger = logging.get_logger(__name__)


class SD3Transformer2DLoadersMixin:
    """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""

    def _convert_ip_adapter_attn_to_diffusers(
        self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
    ) -> Dict:
        if low_cpu_mem_usage:
            if is_accelerate_available():
                from accelerate import init_empty_weights

            else:
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        # IP-Adapter cross attention parameters
        hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
        ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
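        # The timestep embedding width is read off the checkpoint itself:
        # `norm_ip.linear` consumes the timestep embedding, so its input width
        # (shape[1]) equals the embedding dimension.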
        timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]

        # Dict where key is transformer layer index, value is attention processor's state dict
        # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
        layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
        for key, weights in state_dict.items():
            idx, name = key.split(".", maxsplit=1)
            layer_state_dict[int(idx)][name] = weights

        # Create IP-Adapter attention processors & load state_dict
        attn_procs = {}
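        # With `low_cpu_mem_usage`, processors are created on the meta device and
        # only materialized when `load_model_dict_into_meta` copies the weights in.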
        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
        for idx, name in enumerate(self.attn_processors.keys()):
            with init_context():
                attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
                    hidden_size=hidden_size,
                    ip_hidden_states_dim=ip_hidden_states_dim,
                    head_dim=self.config.attention_head_dim,
                    timesteps_emb_dim=timesteps_emb_dim,
                )

            if not low_cpu_mem_usage:
                attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
            else:
                device_map = {"": self.device}
                load_model_dict_into_meta(
                    attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
                )

        empty_device_cache()

        return attn_procs

    def _convert_ip_adapter_image_proj_to_diffusers(
        self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
    ) -> IPAdapterTimeImageProjection:
        if low_cpu_mem_usage:
            if is_accelerate_available():
                from accelerate import init_empty_weights

            else:
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext

        # Convert checkpoint keys to the diffusers naming convention
        updated_state_dict = {}
        for key, value in state_dict.items():
            # InstantX/SD3.5-Large-IP-Adapter
            if key.startswith("layers."):
                idx = key.split(".")[1]
                key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
                key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
                key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
                key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
                key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
                key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
                key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
                key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
                key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
            updated_state_dict[key] = value
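        # After remapping, e.g. "layers.0.0.norm1.weight" from the original
        # checkpoint is stored as "layers.0.ln0.weight".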

        # Image projection parameters
        embed_dim = updated_state_dict["proj_in.weight"].shape[1]
        output_dim = updated_state_dict["proj_out.weight"].shape[0]
        hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
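        # Infer the head count assuming a fixed per-head dimension of 64.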
        heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
        num_queries = updated_state_dict["latents"].shape[1]
        timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]

        # Image projection
        with init_context():
            image_proj = IPAdapterTimeImageProjection(
                embed_dim=embed_dim,
                output_dim=output_dim,
                hidden_dim=hidden_dim,
                heads=heads,
                num_queries=num_queries,
                timestep_in_dim=timestep_in_dim,
            )

        if not low_cpu_mem_usage:
            image_proj.load_state_dict(updated_state_dict, strict=True)
        else:
            device_map = {"": self.device}
            load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
            empty_device_cache()

        return image_proj

    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
        """Sets IP-Adapter attention processors, image projection, and loads state_dict.

        Args:
            state_dict (`Dict`):
                State dict with keys "ip_adapter", which contains parameters for the attention processors, and
                "image_proj", which contains parameters for the image projection net.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
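
        Example `state_dict` layout (an illustrative sketch, not an exhaustive schema):

            {
                "image_proj": {...},  # weights for `IPAdapterTimeImageProjection`
                "ip_adapter": {...},  # attention processor weights keyed by layer index,
                                      # e.g. "0.norm_ip.linear.weight"
            }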
        """

        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
        self.set_attn_processor(attn_procs)

        self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)