unet.py 44.7 KB
Newer Older
Aryan's avatar
Aryan committed
1
# Copyright 2025 The HuggingFace Team. All rights reserved.
2
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
15
from collections import defaultdict
16
from contextlib import nullcontext
gzguevara's avatar
gzguevara committed
17
from pathlib import Path
18
from typing import Callable, Dict, Union
19
20
21

import safetensors
import torch
22
import torch.nn.functional as F
23
from huggingface_hub.utils import validate_hf_hub_args
24

25
26
from ..models.embeddings import (
    ImageProjection,
27
28
    IPAdapterFaceIDImageProjection,
    IPAdapterFaceIDPlusImageProjection,
29
30
31
32
    IPAdapterFullImageProjection,
    IPAdapterPlusImageProjection,
    MultiIPAdapterImageProjection,
)
33
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
34
35
36
from ..utils import (
    USE_PEFT_BACKEND,
    _get_model_file,
37
    convert_unet_state_dict_to_peft,
38
    deprecate,
39
40
    get_adapter_name,
    get_peft_kwargs,
41
    is_accelerate_available,
42
    is_peft_version,
43
    is_torch_version,
44
45
    logging,
)
46
from ..utils.torch_utils import empty_device_cache
47
from .lora_base import _func_optionally_disable_offloading
48
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
49
50
51
52
53
54
55
56
57
58
59
from .utils import AttnProcsLayers


logger = logging.get_logger(__name__)


CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"


class UNet2DConditionLoadersMixin:
Steven Liu's avatar
Steven Liu committed
60
61
62
63
    """
    Load LoRA layers into a [`UNet2DConditionModel`].
    """

64
65
66
    text_encoder_name = TEXT_ENCODER_NAME
    unet_name = UNET_NAME

67
    @validate_hf_hub_args
    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        r"""
        Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
        defined in
        [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
        and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install
        `peft`: `pip install -U peft`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a directory (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            network_alphas (`Dict[str, float]`):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            adapter_name (`str`, *optional*, defaults to None):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            weight_name (`str`, *optional*, defaults to None):
                Name of the serialized state dict file.
            use_safetensors (`bool`, *optional*):
                If `None` (default), a `.safetensors` weight file is tried first and the pickle-based weight file is
                used as a fallback. If `True`, a failure to load the `.safetensors` file is raised instead of falling
                back. If `False`, the pickle-based file is loaded unless `weight_name` ends with `.safetensors`.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.

        Raises:
            ValueError: If `low_cpu_mem_usage` is requested with an incompatible `peft` version, or if the state dict
                is neither a Custom Diffusion nor a LoRA state dict.

        Example:

        ```py
        from diffusers import AutoPipelineForText2Image
        import torch

        pipeline = AutoPipelineForText2Image.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
        ).to("cuda")
        pipeline.unet.load_attn_procs(
            "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
        )
        ```
        """
        from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading

        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        adapter_name = kwargs.pop("adapter_name", None)
        _pipeline = kwargs.pop("_pipeline", None)
        network_alphas = kwargs.pop("network_alphas", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
        allow_pickle = False

        if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # With no explicit preference, prefer safetensors but allow falling back to pickled weights.
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        model_file = None
        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            # Let's first try to load .safetensors weights
            if (use_safetensors and weight_name is None) or (
                weight_name is not None and weight_name.endswith(".safetensors")
            ):
                try:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path_or_dict,
                        weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        token=token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                    )
                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
                except IOError:
                    if not allow_pickle:
                        raise
                    # try loading non-safetensors weights
            if model_file is None:
                model_file = _get_model_file(
                    pretrained_model_name_or_path_or_dict,
                    weights_name=weight_name or LORA_WEIGHT_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )
                state_dict = load_state_dict(model_file)
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        is_group_offload = False

        if is_lora:
            deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
            deprecate("load_attn_procs", "0.40.0", deprecation_message)

        if is_custom_diffusion:
            attn_processors = self._process_custom_diffusion(state_dict=state_dict)
        elif is_lora:
            # `_process_lora` injects the adapter itself and reports which offloading modes it had to disable.
            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
                state_dict=state_dict,
                unet_identifier_key=self.unet_name,
                network_alphas=network_alphas,
                adapter_name=adapter_name,
                _pipeline=_pipeline,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
        else:
            raise ValueError(
                f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training."
            )

        # <Unsafe code
        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
        # Now we remove any existing hooks to `_pipeline`.

        # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
        if is_custom_diffusion:
            # Offloading hooks can only exist when the UNet is attached to a pipeline.
            if _pipeline is not None:
                offload_flags = self._optionally_disable_offloading(_pipeline=_pipeline)
                is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = offload_flags

            # Only Custom Diffusion needs to set attention processors explicitly. This must also happen when the
            # method is called directly on the UNet (no `_pipeline`); otherwise the loaded weights are discarded.
            self.set_attn_processor(attn_processors)
            self.to(dtype=self.dtype, device=self.device)

        # Offload back.
        if is_model_cpu_offload:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()
        elif is_group_offload:
            for component in _pipeline.components.values():
                if isinstance(component, torch.nn.Module):
                    _maybe_remove_and_reapply_group_offloading(component)
        # Unsafe code />
255

256
257
258
259
260
261
262
263
264
265
266
    def _process_custom_diffusion(self, state_dict):
        """
        Build Custom Diffusion attention processors from a flat Custom Diffusion `state_dict`.

        Every key is split into an attention-processor path and a parameter sub-key. Keys containing `to_out` use their
        last three dot-separated components as the sub-key; all other keys use their last two. A value of length zero
        (for example the `{}` placeholder written for parameter-less processors) produces a processor with
        `train_kv=False` and `train_q_out=False`.

        Args:
            state_dict (`dict`): Flat mapping of Custom Diffusion parameter names to tensors (or empty placeholders).

        Returns:
            `dict`: Mapping from attention-processor path to a `CustomDiffusionAttnProcessor` with its weights loaded.
        """
        from ..models.attention_processor import CustomDiffusionAttnProcessor

        grouped = defaultdict(dict)
        for full_key, tensor in state_dict.items():
            if len(tensor) == 0:
                grouped[full_key] = {}
                continue
            parts = full_key.split(".")
            split_at = -3 if "to_out" in full_key else -2
            grouped[".".join(parts[:split_at])][".".join(parts[split_at:])] = tensor

        processors = {}
        for proc_name, params in grouped.items():
            if not params:
                processors[proc_name] = CustomDiffusionAttnProcessor(
                    train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
                )
                continue

            # The key projection weight determines both the hidden size (rows) and cross-attention dim (columns).
            to_k_shape = params["to_k_custom_diffusion.weight"].shape
            processor = CustomDiffusionAttnProcessor(
                train_kv=True,
                train_q_out="to_q_custom_diffusion.weight" in params,
                hidden_size=to_k_shape[0],
                cross_attention_dim=to_k_shape[1],
            )
            processor.load_state_dict(params)
            processors[proc_name] = processor

        return processors

290
291
292
    def _process_lora(
        self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline, low_cpu_mem_usage
    ):
        """
        Inject the LoRA weights contained in `state_dict` into this UNet through `peft`.

        Steps:
        1. Keep only keys starting with `unet_identifier_key` and strip the `"{unet_identifier_key}."` prefix. If no
           key matches, the state dict is treated as legacy format and used unfiltered. `network_alphas` is filtered
           by the same prefix (without the legacy fallback).
        2. Convert the state dict (and the alphas) to the `peft` format.
        3. Build a `LoraConfig`, inject the adapter into the UNet and load the converted weights.
        4. Report whether `_pipeline` had any kind of offloading enabled so the caller can restore it.

        Args:
            state_dict (`dict`): LoRA state dict, with or without the `unet_identifier_key` prefix.
            unet_identifier_key (`str`): Prefix identifying UNet keys (e.g. `self.unet_name`).
            network_alphas (`dict` or `None`): Optional kohya-style alphas keyed like `state_dict`.
            adapter_name (`str` or `None`): Adapter name; a default name is generated when `None`.
            _pipeline: Optional pipeline whose offloading is temporarily disabled during injection.
            low_cpu_mem_usage (`bool`): Forwarded to `peft` when the installed version supports it.

        Returns:
            `tuple[bool, bool, bool]`: `(is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)`; all
            `False` when there was nothing to load.

        Raises:
            ValueError: If the PEFT backend is unavailable, `adapter_name` is already in use, or the installed `peft`
                is too old for the DoRA / LoRA-bias weights present in `state_dict`.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict

        # Keep only the UNet keys (single pass) and strip the UNet prefix.
        prefix = f"{unet_identifier_key}."
        unet_state_dict = {
            k.replace(prefix, ""): v for k, v in state_dict.items() if k.startswith(unet_identifier_key)
        }

        if network_alphas is not None:
            network_alphas = {
                k.replace(prefix, ""): v for k, v in network_alphas.items() if k.startswith(unet_identifier_key)
            }

        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        is_group_offload = False
        # Legacy-format state dicts carry no UNet prefix, so fall back to the unfiltered dict.
        state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

        if len(state_dict_to_be_used) > 0:
            if adapter_name in getattr(self, "peft_config", {}):
                raise ValueError(
                    f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
                )

            state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)

            if network_alphas is not None:
                # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
                # `convert_unet_state_dict_to_peft` method.
                network_alphas = convert_unet_state_dict_to_peft(network_alphas)

            # `shape[1]` of every `lora_B` weight is used as that layer's rank.
            rank = {key: val.shape[1] for key, val in state_dict.items() if "lora_B" in key}

            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
            if "use_dora" in lora_config_kwargs:
                if lora_config_kwargs["use_dora"]:
                    if is_peft_version("<", "0.9.0"):
                        raise ValueError(
                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
                        )
                else:
                    # Older `peft` versions do not accept the `use_dora` argument at all.
                    if is_peft_version("<", "0.9.0"):
                        lora_config_kwargs.pop("use_dora")

            if "lora_bias" in lora_config_kwargs:
                if lora_config_kwargs["lora_bias"]:
                    if is_peft_version("<=", "0.13.2"):
                        raise ValueError(
                            "You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
                        )
                else:
                    # Older `peft` versions do not accept the `lora_bias` argument at all.
                    if is_peft_version("<=", "0.13.2"):
                        lora_config_kwargs.pop("lora_bias")

            lora_config = LoraConfig(**lora_config_kwargs)

            # adapter_name
            if adapter_name is None:
                adapter_name = get_adapter_name(self)

            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
            # otherwise loading LoRA weights will lead to an error
            # NOTE(review): if injection below raises, offloading is not re-enabled here — confirm callers handle it.
            is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
                _pipeline
            )

            peft_kwargs = {}
            if is_peft_version(">=", "0.13.1"):
                peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage

            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)

            warn_msg = ""
            if incompatible_keys is not None:
                # Check only for unexpected keys.
                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
                if unexpected_keys:
                    lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
                    if lora_unexpected_keys:
                        warn_msg = (
                            f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
                            f" {', '.join(lora_unexpected_keys)}. "
                        )

                # Filter missing keys specific to the current adapter.
                missing_keys = getattr(incompatible_keys, "missing_keys", None)
                if missing_keys:
                    lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
                    if lora_missing_keys:
                        warn_msg += (
                            f"Loading adapter weights from state_dict led to missing keys in the model:"
                            f" {', '.join(lora_missing_keys)}."
                        )

            if warn_msg:
                logger.warning(warn_msg)

        return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
406

407
    # Temporarily remove any offloading hooks from `_pipeline` by delegating to `_func_optionally_disable_offloading`.
    # Callers in this class unpack the result as `(is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)`
    # and use those flags to re-enable offloading afterwards. `_process_lora` may pass `_pipeline=None`; presumably the
    # helper then reports no offloading — confirm in `lora_base`.
    # NOTE: documentation lives here, above the decorator, because the body below must stay identical to the
    # source named in its "Copied from" marker.
    @classmethod
    # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
    def _optionally_disable_offloading(cls, _pipeline):
        return _func_optionally_disable_offloading(_pipeline=_pipeline)
411
412
413
414
415
416
417
418
419
420
421

    def save_attn_procs(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
        **kwargs,
    ):
        r"""
        Save attention processor layers to a directory so that they can be reloaded with
        [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`].

        If any attention processor is a Custom Diffusion processor, the Custom Diffusion weights are saved. Otherwise
        the (deprecated) LoRA path saves the `peft` adapter state dict, which requires the PEFT backend.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save the attention processor weights to (created if it doesn't exist). If this points to
                an existing file, an error is logged and nothing is saved.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Intended for distributed training (only the main process should write). Not consulted by this
                implementation.
            weight_name (`str`, *optional*):
                File name for the weights. Defaults to the Custom Diffusion or LoRA weight file name, with the
                `.safetensors` variant when `safe_serialization=True`.
            save_function (`Callable`, *optional*):
                Called as `save_function(state_dict, path)`. Useful during distributed training when `torch.save`
                must be replaced by another method. Defaults to `safetensors.torch.save_file` or `torch.save`
                depending on `safe_serialization`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or with `pickle`. With the default `save_function`,
                non-tensor entries of a Custom Diffusion state dict are dropped (with a warning) because safetensors
                cannot store them.

        Example:

        ```py
        import torch
        from diffusers import DiffusionPipeline

        pipeline = DiffusionPipeline.from_pretrained(
            "CompVis/stable-diffusion-v1-4",
            torch_dtype=torch.float16,
        ).to("cuda")
        pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
        pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
        ```
        """
        from ..models.attention_processor import (
            CustomDiffusionAttnProcessor,
            CustomDiffusionAttnProcessor2_0,
            CustomDiffusionXFormersAttnProcessor,
        )

        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        custom_diffusion_types = (
            CustomDiffusionAttnProcessor,
            CustomDiffusionAttnProcessor2_0,
            CustomDiffusionXFormersAttnProcessor,
        )
        is_custom_diffusion = any(isinstance(proc, custom_diffusion_types) for proc in self.attn_processors.values())

        if not is_custom_diffusion:
            deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`."
            deprecate("save_attn_procs", "0.40.0", deprecation_message)

            if not USE_PEFT_BACKEND:
                raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")

            from peft.utils import get_peft_model_state_dict

            state_dict = get_peft_model_state_dict(self)
        else:
            state_dict = self._get_custom_diffusion_state_dict()
            if save_function is None and safe_serialization:
                # safetensors does not support saving dicts with non-tensor values
                tensor_entries = {}
                dropped_entries = {}
                for entry_key, entry_value in state_dict.items():
                    bucket = tensor_entries if isinstance(entry_value, torch.Tensor) else dropped_entries
                    bucket[entry_key] = entry_value
                if len(dropped_entries) > 0:
                    logger.warning(
                        f"Safetensors does not support saving dicts with non-tensor values. "
                        f"The following keys will be ignored: {dropped_entries.keys()}"
                    )
                state_dict = tensor_entries

        if save_function is None and safe_serialization:

            def save_function(weights, filename):
                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

        elif save_function is None:
            save_function = torch.save

        os.makedirs(save_directory, exist_ok=True)

        if weight_name is None:
            if is_custom_diffusion:
                weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if safe_serialization else CUSTOM_DIFFUSION_WEIGHT_NAME
            else:
                weight_name = LORA_WEIGHT_NAME_SAFE if safe_serialization else LORA_WEIGHT_NAME

        # Save the model
        target_path = Path(save_directory, weight_name).as_posix()
        save_function(state_dict, target_path)
        logger.info(f"Model weights saved in {target_path}")
513

514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
    def _get_custom_diffusion_state_dict(self):
        """
        Collect the weights of every Custom Diffusion attention processor on this model.

        Returns:
            `dict`: The `AttnProcsLayers` state dict of all Custom Diffusion processors, plus an empty-dict
            placeholder keyed by processor name for each Custom Diffusion processor that has no parameters, so that
            `_process_custom_diffusion` can recreate it on load.
        """
        from ..models.attention_processor import (
            CustomDiffusionAttnProcessor,
            CustomDiffusionAttnProcessor2_0,
            CustomDiffusionXFormersAttnProcessor,
        )

        custom_diffusion_types = (
            CustomDiffusionAttnProcessor,
            CustomDiffusionAttnProcessor2_0,
            CustomDiffusionXFormersAttnProcessor,
        )
        custom_diffusion_processors = {
            name: processor
            for name, processor in self.attn_processors.items()
            if isinstance(processor, custom_diffusion_types)
        }

        state_dict = AttnProcsLayers(custom_diffusion_processors).state_dict()
        # Only the Custom Diffusion processors are inspected: other attention processors may not expose
        # `state_dict()`, and they are not part of what gets saved anyway.
        for name, processor in custom_diffusion_processors.items():
            if len(processor.state_dict()) == 0:
                state_dict[name] = {}

        return state_dict

542
    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
        if low_cpu_mem_usage:
            if is_accelerate_available():
                from accelerate import init_empty_weights

            else:
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

562
563
        updated_state_dict = {}
        image_projection = None
564
        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
565
566
567
568
569
570
571

        if "proj.weight" in state_dict:
            # IP-Adapter
            num_image_text_embeds = 4
            clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
            cross_attention_dim = state_dict["proj.weight"].shape[0] // 4

572
573
574
575
576
577
            with init_context():
                image_projection = ImageProjection(
                    cross_attention_dim=cross_attention_dim,
                    image_embed_dim=clip_embeddings_dim,
                    num_image_text_embeds=num_image_text_embeds,
                )
578
579
580
581
582
583
584
585
586
587

            for key, value in state_dict.items():
                diffusers_name = key.replace("proj", "image_embeds")
                updated_state_dict[diffusers_name] = value

        elif "proj.3.weight" in state_dict:
            # IP-Adapter Full
            clip_embeddings_dim = state_dict["proj.0.weight"].shape[0]
            cross_attention_dim = state_dict["proj.3.weight"].shape[0]

588
589
590
591
            with init_context():
                image_projection = IPAdapterFullImageProjection(
                    cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim
                )
592
593
594
595
596
597
598

            for key, value in state_dict.items():
                diffusers_name = key.replace("proj.0", "ff.net.0.proj")
                diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
                diffusers_name = diffusers_name.replace("proj.3", "norm")
                updated_state_dict[diffusers_name] = value

599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
        elif "perceiver_resampler.proj_in.weight" in state_dict:
            # IP-Adapter Face ID Plus
            id_embeddings_dim = state_dict["proj.0.weight"].shape[1]
            embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0]
            hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1]
            output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0]
            heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64

            with init_context():
                image_projection = IPAdapterFaceIDPlusImageProjection(
                    embed_dims=embed_dims,
                    output_dims=output_dims,
                    hidden_dims=hidden_dims,
                    heads=heads,
                    id_embeddings_dim=id_embeddings_dim,
                )

            for key, value in state_dict.items():
                diffusers_name = key.replace("perceiver_resampler.", "")
                diffusers_name = diffusers_name.replace("0.to", "attn.to")
                diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.")
                diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight")
                diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight")
                diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.")
                diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight")
                diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight")
                diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.")
                diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight")
                diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight")
                diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.")
                diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight")
                diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight")
                diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0")
                diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1")
                diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0")
                diffusers_name = diffusers_name.replace("layers.1.1", "layers.1.ln1")
                diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0")
                diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1")
                diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0")
                diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1")

                if "norm1" in diffusers_name:
                    updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value
                elif "norm2" in diffusers_name:
                    updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value
                elif "to_kv" in diffusers_name:
                    v_chunk = value.chunk(2, dim=0)
                    updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
                    updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
                elif "to_out" in diffusers_name:
                    updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
                elif "proj.0.weight" == diffusers_name:
                    updated_state_dict["proj.net.0.proj.weight"] = value
                elif "proj.0.bias" == diffusers_name:
                    updated_state_dict["proj.net.0.proj.bias"] = value
                elif "proj.2.weight" == diffusers_name:
                    updated_state_dict["proj.net.2.weight"] = value
                elif "proj.2.bias" == diffusers_name:
                    updated_state_dict["proj.net.2.bias"] = value
                else:
                    updated_state_dict[diffusers_name] = value

        elif "norm.weight" in state_dict:
            # IP-Adapter Face ID
            id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
            id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
            multiplier = id_embeddings_dim_out // id_embeddings_dim_in
            norm_layer = "norm.weight"
            cross_attention_dim = state_dict[norm_layer].shape[0]
            num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim

            with init_context():
                image_projection = IPAdapterFaceIDImageProjection(
                    cross_attention_dim=cross_attention_dim,
                    image_embed_dim=id_embeddings_dim_in,
                    mult=multiplier,
                    num_tokens=num_tokens,
                )

            for key, value in state_dict.items():
                diffusers_name = key.replace("proj.0", "ff.net.0.proj")
                diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
                updated_state_dict[diffusers_name] = value

683
684
685
686
687
688
        else:
            # IP-Adapter Plus
            num_image_text_embeds = state_dict["latents"].shape[1]
            embed_dims = state_dict["proj_in.weight"].shape[1]
            output_dims = state_dict["proj_out.weight"].shape[0]
            hidden_dims = state_dict["latents"].shape[2]
689
690
691
692
693
694
            attn_key_present = any("attn" in k for k in state_dict)
            heads = (
                state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
                if attn_key_present
                else state_dict["layers.0.0.to_q.weight"].shape[0] // 64
            )
695

696
697
698
699
700
701
702
703
            with init_context():
                image_projection = IPAdapterPlusImageProjection(
                    embed_dims=embed_dims,
                    output_dims=output_dims,
                    hidden_dims=hidden_dims,
                    heads=heads,
                    num_queries=num_image_text_embeds,
                )
704
705
706
707

            for key, value in state_dict.items():
                diffusers_name = key.replace("0.to", "2.to")

708
709
710
711
712
713
714
715
716
717
718
719
720
                diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0")
                diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1")
                diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0")
                diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1")
                diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0")
                diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1")
                diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0")
                diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1")

                if "to_kv" in diffusers_name:
                    parts = diffusers_name.split(".")
                    parts[2] = "attn"
                    diffusers_name = ".".join(parts)
721
722
723
                    v_chunk = value.chunk(2, dim=0)
                    updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
                    updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
724
725
726
727
728
                elif "to_q" in diffusers_name:
                    parts = diffusers_name.split(".")
                    parts[2] = "attn"
                    diffusers_name = ".".join(parts)
                    updated_state_dict[diffusers_name] = value
729
                elif "to_out" in diffusers_name:
730
731
732
                    parts = diffusers_name.split(".")
                    parts[2] = "attn"
                    diffusers_name = ".".join(parts)
733
734
                    updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
                else:
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
                    diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0")
                    diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj")
                    diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2")

                    diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0")
                    diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj")
                    diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2")

                    diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0")
                    diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj")
                    diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2")

                    diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0")
                    diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj")
                    diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2")
750
751
                    updated_state_dict[diffusers_name] = value

752
        if not low_cpu_mem_usage:
753
            image_projection.load_state_dict(updated_state_dict, strict=True)
754
        else:
755
756
            device_map = {"": self.device}
            load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
757
            empty_device_cache()
758

759
760
        return image_projection

761
    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
        """
        Build the attention processors needed to run one or more IP-Adapters on this UNet.

        Processors whose name ends in `attn1.processor` or contains `motion_modules` are re-instantiated from their
        current class without IP-Adapter weights. Every other processor is replaced by an IP-Adapter processor:
        - the xFormers variant if the current processor class name contains "XFormers";
        - otherwise the `F.scaled_dot_product_attention` variant when available, else the plain one.
        Each replacement's `to_k_ip.{i}` / `to_v_ip.{i}` weights are loaded from the `i`-th state dict.

        Args:
            state_dicts (`list[dict]`):
                IP-Adapter state dicts. Each needs an `"image_proj"` sub-dict, used here only to infer how many image
                tokens the adapter produces. Each also needs an `"ip_adapter"` sub-dict with
                `"{key_id}.to_k_ip.weight"` / `"{key_id}.to_v_ip.weight"` entries.
            low_cpu_mem_usage (`bool`, *optional*):
                Create the IP-Adapter processors under `accelerate.init_empty_weights` and load their weights with
                `load_model_dict_into_meta`. Falls back to `False`, with a warning, if `accelerate` is not installed.

        Returns:
            `dict`: Processor name -> processor instance, in the order of `self.attn_processors`.

        Raises:
            NotImplementedError: If low-memory loading is requested with torch < 1.9.0.
        """
        from ..models.attention_processor import (
            IPAdapterAttnProcessor,
            IPAdapterAttnProcessor2_0,
            IPAdapterXFormersAttnProcessor,
        )

        if low_cpu_mem_usage:
            if is_accelerate_available():
                from accelerate import init_empty_weights
            else:
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

        if low_cpu_mem_usage and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        # The number of image tokens per adapter depends only on its image-projection weights. It is therefore the
        # same for every cross-attention processor, so it is computed once here.
        num_image_text_embeds = []
        for state_dict in state_dicts:
            image_proj = state_dict["image_proj"]
            if "proj.weight" in image_proj:
                # IP-Adapter
                num_image_text_embeds.append(4)
            elif "proj.3.weight" in image_proj:
                # IP-Adapter Full Face
                num_image_text_embeds.append(257)  # 256 CLIP tokens + 1 CLS token
            elif "perceiver_resampler.proj_in.weight" in image_proj:
                # IP-Adapter Face ID Plus
                num_image_text_embeds.append(4)
            elif "norm.weight" in image_proj:
                # IP-Adapter Face ID
                num_image_text_embeds.append(4)
            else:
                # IP-Adapter Plus: one token per learned latent query
                num_image_text_embeds.append(image_proj["latents"].shape[1])

        # `attn_processors` is a property that is rebuilt on every access; read it once.
        current_procs = self.attn_processors

        # set ip-adapter cross-attention processors & load state_dict
        attn_procs = {}
        # Checkpoint index of the next IP-Adapter processor's weights. It starts at 1 and advances by 2 per
        # IP-Adapter processor. NOTE(review): this assumes the checkpoint's key layout matches that stride.
        key_id = 1
        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
        for name in current_procs:
            cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
            # NOTE(review): a name matching none of these prefixes reuses `hidden_size` from the previous iteration
            # (or leaves it unbound on the first iteration).
            if name.startswith("mid_block"):
                hidden_size = self.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                # Parse the whole index segment (up to the next "."), so multi-digit block indices work.
                block_id = int(name[len("up_blocks.") :].split(".", 1)[0])
                hidden_size = list(reversed(self.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.") :].split(".", 1)[0])
                hidden_size = self.config.block_out_channels[block_id]

            if cross_attention_dim is None or "motion_modules" in name:
                # Self-attention / motion-module processors: fresh instance of the current class, no IP weights.
                attn_procs[name] = current_procs[name].__class__()
                continue

            if "XFormers" in str(current_procs[name].__class__):
                attn_processor_class = IPAdapterXFormersAttnProcessor
            elif hasattr(F, "scaled_dot_product_attention"):
                attn_processor_class = IPAdapterAttnProcessor2_0
            else:
                attn_processor_class = IPAdapterAttnProcessor

            with init_context():
                attn_procs[name] = attn_processor_class(
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                    # Give each processor its own copy so they never share one mutable list.
                    num_tokens=list(num_image_text_embeds),
                )

            value_dict = {}
            for i, state_dict in enumerate(state_dicts):
                value_dict[f"to_k_ip.{i}.weight"] = state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]
                value_dict[f"to_v_ip.{i}.weight"] = state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]

            if not low_cpu_mem_usage:
                attn_procs[name].load_state_dict(value_dict)
            else:
                # Materialize the meta-device processor on the device/dtype of the checkpoint tensors.
                first_value = next(iter(value_dict.values()))
                device_map = {"": first_value.device}
                load_model_dict_into_meta(
                    attn_procs[name], value_dict, device_map=device_map, dtype=first_value.dtype
                )

            key_id += 2

        empty_device_cache()

        return attn_procs

    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
        """
        Install one or more IP-Adapters into this UNet.

        The method:
        - replaces the attention processors with those built by `_convert_ip_adapter_attn_to_diffusers`;
        - converts each adapter's `"image_proj"` weights with `_convert_ip_adapter_image_proj_to_diffusers`;
        - wraps the results in a `MultiIPAdapterImageProjection` assigned to `self.encoder_hid_proj`;
        - sets `config.encoder_hid_dim_type` to `"ip_image_proj"`;
        - finally moves the module back to its own dtype/device.

        Args:
            state_dicts (`dict`, or `list`/`tuple` of `dict`):
                IP-Adapter state dict(s), each with `"image_proj"` and `"ip_adapter"` sub-dicts. A single dict is
                treated as a one-element list.
            low_cpu_mem_usage (`bool`, *optional*):
                Forwarded to both conversion helpers.
        """
        if isinstance(state_dicts, tuple):
            state_dicts = list(state_dicts)
        elif not isinstance(state_dicts, list):
            state_dicts = [state_dicts]

        # Kolors Unet already has a `encoder_hid_proj` (a text projection). Keep it as `text_encoder_hid_proj`
        # before it is replaced by the image projection below, unless it has already been stashed.
        if (
            self.encoder_hid_proj is not None
            and self.config.encoder_hid_dim_type == "text_proj"
            and not hasattr(self, "text_encoder_hid_proj")
        ):
            self.text_encoder_hid_proj = self.encoder_hid_proj

        # Detach `encoder_hid_proj` while the attention processors are replaced and re-assign it afterwards, because
        # `IPAdapterPlusImageProjection` also has `attn_processors` that must not be mixed with the UNet's own.
        self.encoder_hid_proj = None

        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
        self.set_attn_processor(attn_procs)

        # convert IP-Adapter Image Projection layers to diffusers
        image_projection_layers = [
            self._convert_ip_adapter_image_proj_to_diffusers(
                state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
            )
            for state_dict in state_dicts
        ]

        self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
        self.config.encoder_hid_dim_type = "ip_image_proj"

        self.to(dtype=self.dtype, device=self.device)

    def _load_ip_adapter_loras(self, state_dicts):
        """
        Gather the LoRA weights shipped inside IP-Adapter checkpoints, grouped per adapter.

        Each adapter's `"ip_adapter"` sub-dict stores LoRA tensors under `"{index}.{proj}_lora.{down|up}.weight"`.
        Here `index` is the position of an attention processor in `self.attn_processors` and `proj` is one of
        `to_k`, `to_q`, `to_v`, `to_out`. A processor is considered to carry LoRA weights when its
        `to_k_lora.down.weight` entry is present. All eight tensors are then required; a missing one raises
        `KeyError`.

        Args:
            state_dicts: Sequence of IP-Adapter state dicts, each with an `"ip_adapter"` sub-dict.

        Returns:
            `dict[int, dict]`: Maps the position of each adapter (in `state_dicts`) that has any LoRA weights to a
            dict of `"unet.{processor_name}.{proj}_lora.{down|up}.weight"` -> tensor. Adapters without LoRA weights
            are absent.
        """
        # All down weights first, then all up weights, each in k/q/v/out order.
        lora_suffixes = [
            f"{proj}_lora.{direction}.weight"
            for direction in ("down", "up")
            for proj in ("to_k", "to_q", "to_v", "to_out")
        ]

        collected = {}
        for proc_index, proc_name in enumerate(self.attn_processors):
            for adapter_index, adapter_sd in enumerate(state_dicts):
                ip_weights = adapter_sd["ip_adapter"]
                if f"{proc_index}.to_k_lora.down.weight" not in ip_weights:
                    continue
                target = collected.setdefault(adapter_index, {})
                for suffix in lora_suffixes:
                    target[f"unet.{proc_name}.{suffix}"] = ip_weights[f"{proc_index}.{suffix}"]
        return collected