Unverified Commit a0542c19 authored by Sayak Paul, committed by GitHub

[LoRA] Remove legacy LoRA code and related adjustments (#8316)

* remove legacy code from load_attn_procs.

* finish first draft

* fix more.

* fix more

* add test

* add serialization support.

* fix-copies

* require peft backend for lora tests

* style

* fix test

* fix loading.

* empty

* address Benjamin's feedback.
parent a8ad6664
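For orientation, this is roughly the round trip the new model-level tests exercise once the legacy code paths are gone; a minimal sketch assuming `peft` is installed, with the checkpoint ID and output directory chosen purely for illustration:

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Illustrative checkpoint; any UNet2DConditionModel works the same way.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Inject a small LoRA adapter through PEFT (mirrors get_unet_lora_config in the new tests).
unet.add_adapter(
    LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])
)

# Serialization now goes through PEFT's get_peft_model_state_dict ...
unet.save_attn_procs("lora-out")
# ... and the injected layers can be stripped again with the new unload_lora().
unet.unload_lora()
unet.load_attn_procs("lora-out/pytorch_lora_weights.safetensors")
```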
@@ -111,3 +111,21 @@ jobs:
-s -v \
--make-reports=tests_${{ matrix.config.report }} \
tests/lora/
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
-s -v \
--make-reports=tests_models_lora_${{ matrix.config.report }} \
tests/models/ -k "lora"
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_${{ matrix.config.report }}_failures_short.txt
cat reports/tests_models_lora_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v2
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports
\ No newline at end of file
@@ -189,12 +189,17 @@ jobs:
-s -v -k "not Flax and not Onnx and not PEFTLoRALoading" \
--make-reports=tests_peft_cuda \
tests/lora/
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
-s -v -k "lora and not Flax and not Onnx and not PEFTLoRALoading" \
--make-reports=tests_peft_cuda_models_lora \
tests/models/
- name: Failure short reports
if: ${{ failure() }}
run: |
cat reports/tests_peft_cuda_stats.txt
cat reports/tests_peft_cuda_failures_short.txt
cat reports/tests_peft_cuda_models_lora_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
...
@@ -22,17 +22,14 @@ import torch
from huggingface_hub import model_info
from huggingface_hub.constants import HF_HUB_OFFLINE
from huggingface_hub.utils import validate_hf_hub_args
from packaging import version
from torch import nn
from .. import __version__
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..models.modeling_utils import load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
@@ -119,13 +116,10 @@ class LoraLoaderMixin:
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
self.load_lora_into_unet(
state_dict,
network_alphas=network_alphas,
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
_pipeline=self,
)
@@ -136,7 +130,6 @@ class LoraLoaderMixin:
if not hasattr(self, "text_encoder")
else self.text_encoder,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
_pipeline=self,
)
@@ -193,16 +186,8 @@ class LoraLoaderMixin:
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
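For context, `weight_name` simply selects which serialized file to fetch; a hedged usage sketch of `lora_state_dict` (the repository IDs below are illustrative, and `load_lora_weights` forwards the same kwargs):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Returns the raw LoRA state dict plus any network alphas without touching the pipeline.
state_dict, network_alphas = pipe.lora_state_dict(
    "some-user/some-lora",  # hypothetical Hub repo
    weight_name="pytorch_lora_weights.safetensors",
)
print(len(state_dict), network_alphas)
```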
@@ -383,9 +368,7 @@ class LoraLoaderMixin:
return (is_model_cpu_offload, is_sequential_cpu_offload)
@classmethod
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
):
def load_lora_into_unet(cls, state_dict, network_alphas, unet, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -395,14 +378,11 @@ class LoraLoaderMixin:
into the unet or prefixed with an additional `unet` which can be used to distinguish between text
encoder lora layers.
network_alphas (`Dict[str, float]`):
See `LoRALinearLayer` for more details.
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
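To make the prefix handling below concrete: in the post-#2918 format every key carries a `unet.` or `text_encoder.` prefix, and the simplified check now loads into the UNet whenever at least one `unet.`-prefixed key exists. A small sketch with made-up module paths:

```python
import torch

# Illustrative keys only; real checkpoints contain many such entries.
state_dict = {
    "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_A.weight": torch.zeros(4, 320),
    "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(4, 768),
}

keys = list(state_dict.keys())
only_text_encoder = all(k.startswith("text_encoder") for k in keys)
load_into_unet = any(k.startswith("unet") for k in keys) and not only_text_encoder
print(load_into_unet)  # True for this mixed checkpoint
```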
@@ -410,94 +390,18 @@ class LoraLoaderMixin:
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder:
# Load the layers corresponding to UNet.
logger.info(f"Loading {cls.unet_name}.")
unet_keys = [k for k in keys if k.startswith(cls.unet_name)]
state_dict = {k.replace(f"{cls.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, adapter_name=adapter_name, _pipeline=_pipeline
)
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(cls.unet_name)]
network_alphas = {
k.replace(f"{cls.unet_name}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
else:
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
if not USE_PEFT_BACKEND:
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
logger.warning(warn_message)
if len(state_dict.keys()) > 0:
if adapter_name in getattr(unet, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
)
state_dict = convert_unet_state_dict_to_peft(state_dict)
if network_alphas is not None:
# The alphas state dict have the same structure as Unet, thus we convert it to peft format using
# `convert_unet_state_dict_to_peft` method.
network_alphas = convert_unet_state_dict_to_peft(network_alphas)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(unet)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
)
@classmethod
def load_lora_into_text_encoder(
@@ -507,7 +411,6 @@ class LoraLoaderMixin:
text_encoder,
prefix=None,
lora_scale=1.0,
low_cpu_mem_usage=None,
adapter_name=None,
_pipeline=None,
):
@@ -527,11 +430,6 @@ class LoraLoaderMixin:
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added with the output of the regular
lora layer.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
@@ -541,8 +439,6 @@ class LoraLoaderMixin:
from peft import LoraConfig
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
@@ -625,9 +521,7 @@ class LoraLoaderMixin:
# Unsafe code />
@classmethod
def load_lora_into_transformer(
cls, state_dict, network_alphas, transformer, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
):
def load_lora_into_transformer(cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -640,19 +534,12 @@ class LoraLoaderMixin:
See `LoRALinearLayer` for more details.
unet (`UNet2DConditionModel`):
The UNet model to load the LoRA layers into.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)]
@@ -846,22 +733,11 @@ class LoraLoaderMixin:
>>> ...
```
"""
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
if not USE_PEFT_BACKEND:
if version.parse(__version__) > version.parse("0.23"):
logger.warning(
"You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
"you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
)
for _, module in unet.named_modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
else:
recurse_remove_peft_layers(unet)
if hasattr(unet, "peft_config"):
del unet.peft_config
raise ValueError("PEFT backend is required for this method.")
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
unet.unload_lora()
# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()
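With the legacy branch gone, `unload_lora_weights` is PEFT-only and delegates the UNet cleanup to the new `unet.unload_lora()`. A minimal usage sketch (pipeline and LoRA repo IDs are illustrative):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("some-user/some-lora", weight_name="pytorch_lora_weights.safetensors")

# Raises a ValueError if the peft backend is unavailable; otherwise strips the
# adapter layers from the UNet and undoes the text-encoder patching.
pipe.unload_lora_weights()
```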
...
@@ -33,34 +33,32 @@ from ..models.embeddings import (
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_version,
is_torch_version,
logging,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .unet_loader_utils import _maybe_expand_lora_scales
from .utils import AttnProcsLayers
if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
@@ -79,7 +77,8 @@ class UNet2DConditionLoadersMixin:
Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
defined in
[`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
and be a `torch.nn.Module` class.
and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install
`peft`: `pip install -U peft`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -110,20 +109,20 @@ class UNet2DConditionLoadersMixin:
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
mirror (`str`, *optional*):
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
information.
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
adapter_name (`str`, *optional*, defaults to None):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
Example:
@@ -139,9 +138,6 @@ class UNet2DConditionLoadersMixin:
)
```
""" """
from ..models.attention_processor import CustomDiffusionAttnProcessor
from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
cache_dir = kwargs.pop("cache_dir", None) cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False) force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", None) resume_download = kwargs.pop("resume_download", None)
...@@ -152,15 +148,9 @@ class UNet2DConditionLoadersMixin: ...@@ -152,15 +148,9 @@ class UNet2DConditionLoadersMixin:
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT) adapter_name = kwargs.pop("adapter_name", None)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None) _pipeline = kwargs.pop("_pipeline", None)
network_alphas = kwargs.pop("network_alphas", None)
is_network_alphas_none = network_alphas is None
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
@@ -216,198 +206,196 @@ class UNet2DConditionLoadersMixin:
else:
state_dict = pretrained_model_name_or_path_or_dict
# fill attn processors
lora_layers_list = []
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if is_lora: if is_custom_diffusion:
# correct keys attn_processors = self._process_custom_diffusion(state_dict=state_dict)
state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas) elif is_lora:
is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
state_dict=state_dict,
unet_identifier_key=self.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
_pipeline=_pipeline,
)
else:
raise ValueError(
f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training."
)
if network_alphas is not None: # <Unsafe code
network_alphas_keys = list(network_alphas.keys()) # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
used_network_alphas_keys = set() # Now we remove any existing hooks to `_pipeline`.
lora_grouped_dict = defaultdict(dict)
mapped_network_alphas = {}
all_keys = list(state_dict.keys())
for key in all_keys:
value = state_dict.pop(key)
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
lora_grouped_dict[attn_processor_key][sub_key] = value
# Create another `mapped_network_alphas` dictionary so that we can properly map them.
if network_alphas is not None:
for k in network_alphas_keys:
if k.replace(".alpha", "") in key:
mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
used_network_alphas_keys.add(k)
if not is_network_alphas_none:
if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
raise ValueError(
f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
)
if len(state_dict) > 0: # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
raise ValueError( if is_custom_diffusion and _pipeline is not None:
f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}" is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
)
for key, value_dict in lora_grouped_dict.items(): # only custom diffusion needs to set attn processors
attn_processor = self self.set_attn_processor(attn_processors)
for sub_key in key.split("."): self.to(dtype=self.dtype, device=self.device)
attn_processor = getattr(attn_processor, sub_key)
# Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
# or add_{k,v,q,out_proj}_proj_lora layers.
rank = value_dict["lora.down.weight"].shape[0]
if isinstance(attn_processor, LoRACompatibleConv):
in_features = attn_processor.in_channels
out_features = attn_processor.out_channels
kernel_size = attn_processor.kernel_size
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
lora = LoRAConv2dLayer(
in_features=in_features,
out_features=out_features,
rank=rank,
kernel_size=kernel_size,
stride=attn_processor.stride,
padding=attn_processor.padding,
network_alpha=mapped_network_alphas.get(key),
)
elif isinstance(attn_processor, LoRACompatibleLinear):
ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
with ctx():
lora = LoRALinearLayer(
attn_processor.in_features,
attn_processor.out_features,
rank,
mapped_network_alphas.get(key),
)
else:
raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()} # Offload back.
lora_layers_list.append((attn_processor, lora)) if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
if low_cpu_mem_usage: def _process_custom_diffusion(self, state_dict):
device = next(iter(value_dict.values())).device from ..models.attention_processor import CustomDiffusionAttnProcessor
dtype = next(iter(value_dict.values())).dtype
load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
else:
lora.load_state_dict(value_dict)
elif is_custom_diffusion: attn_processors = {}
attn_processors = {} custom_diffusion_grouped_dict = defaultdict(dict)
custom_diffusion_grouped_dict = defaultdict(dict) for key, value in state_dict.items():
for key, value in state_dict.items(): if len(value) == 0:
if len(value) == 0: custom_diffusion_grouped_dict[key] = {}
custom_diffusion_grouped_dict[key] = {} else:
if "to_out" in key:
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
else: else:
if "to_out" in key: attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:]) custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
else:
attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
for key, value_dict in custom_diffusion_grouped_dict.items():
if len(value_dict) == 0:
attn_processors[key] = CustomDiffusionAttnProcessor(
train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
)
else:
cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
attn_processors[key] = CustomDiffusionAttnProcessor(
train_kv=True,
train_q_out=train_q_out,
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
)
attn_processors[key].load_state_dict(value_dict)
elif USE_PEFT_BACKEND:
# In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict` return attn_processors
# on the Unet
pass def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
else: # This method does the following things:
raise ValueError( # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training." # format. For legacy format no filtering is applied.
) # 2. Converts the `state_dict` to the `peft` compatible format.
# 3. Creates a `LoraConfig` and then injects the converted `state_dict` into the UNet per the
# `LoraConfig` specs.
# 4. It also reports if the underlying `_pipeline` has any kind of offloading inside of it.
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
keys = list(state_dict.keys())
unet_keys = [k for k in keys if k.startswith(unet_identifier_key)]
unet_state_dict = {
k.replace(f"{unet_identifier_key}.", ""): v for k, v in state_dict.items() if k in unet_keys
}
if network_alphas is not None:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(unet_identifier_key)]
network_alphas = {
k.replace(f"{unet_identifier_key}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to
is_model_cpu_offload = False is_model_cpu_offload = False
is_sequential_cpu_offload = False is_sequential_cpu_offload = False
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `load_lora_weights_into_unet` if len(state_dict_to_be_used) > 0:
if not USE_PEFT_BACKEND: if adapter_name in getattr(self, "peft_config", {}):
if _pipeline is not None: raise ValueError(
for _, component in _pipeline.components.items(): f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): )
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = (
isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info( state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# only custom diffusion needs to set attn processors if network_alphas is not None:
if is_custom_diffusion: # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
self.set_attn_processor(attn_processors) # `convert_unet_state_dict_to_peft` method.
network_alphas = convert_unet_state_dict_to_peft(network_alphas)
# set lora layers
for target_module, lora_layer in lora_layers_list: rank = {}
target_module.set_lora_layer(lora_layer) for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
if is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
if is_peft_version("<", "0.9.0"):
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
self.to(dtype=self.dtype, device=self.device) return is_model_cpu_offload, is_sequential_cpu_offload
# Offload back. @classmethod
if is_model_cpu_offload: # Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
_pipeline.enable_model_cpu_offload() def _optionally_disable_offloading(cls, _pipeline):
elif is_sequential_cpu_offload: """
_pipeline.enable_sequential_cpu_offload() Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
# Unsafe code />
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas): Args:
is_new_lora_format = all( _pipeline (`DiffusionPipeline`):
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys() The pipeline to disable offloading for.
)
if is_new_lora_format:
# Strip the `"unet"` prefix.
is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
if is_text_encoder_present:
warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
logger.warning(warn_message)
unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
# change processor format to 'pure' LoRACompatibleLinear format Returns:
if any("processor" in k.split(".") for k in state_dict.keys()): tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
def format_to_lora_compatible(key): if _pipeline is not None and _pipeline.hf_device_map is None:
if "processor" not in key.split("."): for _, component in _pipeline.components.items():
return key if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora") if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()} logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
if network_alphas is not None: return (is_model_cpu_offload, is_sequential_cpu_offload)
network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
return state_dict, network_alphas
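Because the side-by-side rendering above is hard to follow, here is a condensed sketch of what the new `_process_lora` helper does with a prefixed state dict (names abbreviated; the real method also handles `network_alphas`, DoRA/`peft` version checks, adapter-name collisions, and offloading bookkeeping):

```python
from diffusers.utils import convert_unet_state_dict_to_peft, get_peft_kwargs
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict


def process_lora_sketch(unet, state_dict, adapter_name="default_0"):
    # 1. Keep only the "unet."-prefixed entries and strip the prefix.
    unet_sd = {k.replace("unet.", "", 1): v for k, v in state_dict.items() if k.startswith("unet.")}
    # 2. Convert diffusers-style keys to the PEFT naming scheme.
    unet_sd = convert_unet_state_dict_to_peft(unet_sd)
    # 3. Derive ranks from the lora_B matrices and build a LoraConfig.
    rank = {k: v.shape[1] for k, v in unet_sd.items() if "lora_B" in k}
    lora_config = LoraConfig(**get_peft_kwargs(rank, None, unet_sd, is_unet=True))
    # 4. Inject the adapter and load the converted weights into it.
    inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
    set_peft_model_state_dict(unet, unet_sd, adapter_name)
```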
def save_attn_procs(
self,
@@ -460,6 +448,23 @@ class UNet2DConditionLoadersMixin:
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
is_custom_diffusion = any(
isinstance(
x,
(CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
)
for (_, x) in self.attn_processors.items()
)
if is_custom_diffusion:
state_dict = self._get_custom_diffusion_state_dict()
else:
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
from peft.utils import get_peft_model_state_dict
state_dict = get_peft_model_state_dict(self)
if save_function is None:
if safe_serialization:
@@ -471,36 +476,6 @@ class UNet2DConditionLoadersMixin:
os.makedirs(save_directory, exist_ok=True)
is_custom_diffusion = any(
isinstance(
x,
(CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
)
for (_, x) in self.attn_processors.items()
)
if is_custom_diffusion:
model_to_save = AttnProcsLayers(
{
y: x
for (y, x) in self.attn_processors.items()
if isinstance(
x,
(
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
),
)
}
)
state_dict = model_to_save.state_dict()
for name, attn in self.attn_processors.items():
if len(attn.state_dict()) == 0:
state_dict[name] = {}
else:
model_to_save = AttnProcsLayers(self.attn_processors)
state_dict = model_to_save.state_dict()
if weight_name is None:
if safe_serialization:
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
@@ -512,56 +487,84 @@ class UNet2DConditionLoadersMixin:
save_function(state_dict, save_path)
logger.info(f"Model weights saved in {save_path}")
def _get_custom_diffusion_state_dict(self):
from ..models.attention_processor import (
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
)
model_to_save = AttnProcsLayers(
{
y: x
for (y, x) in self.attn_processors.items()
if isinstance(
x,
(
CustomDiffusionAttnProcessor,
CustomDiffusionAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
),
)
}
)
state_dict = model_to_save.state_dict()
for name, attn in self.attn_processors.items():
if len(attn.state_dict()) == 0:
state_dict[name] = {}
return state_dict
def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `fuse_lora()`.")
self.lora_scale = lora_scale
self._safe_fusing = safe_fusing
self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))
def _fuse_lora_apply(self, module, adapter_names=None):
if not USE_PEFT_BACKEND:
if hasattr(module, "_fuse_lora"):
module._fuse_lora(self.lora_scale, self._safe_fusing)
if adapter_names is not None:
raise ValueError(
"The `adapter_names` argument is not supported in your environment. Please switch"
" to PEFT backend to use this argument by installing latest PEFT and transformers."
" `pip install -U peft transformers`"
)
else:
from peft.tuners.tuners_utils import BaseTunerLayer
merge_kwargs = {"safe_merge": self._safe_fusing}
if isinstance(module, BaseTunerLayer):
if self.lora_scale != 1.0:
module.scale_layer(self.lora_scale)
# For BC with prevous PEFT versions, we need to check the signature
# of the `merge` method to see if it supports the `adapter_names` argument.
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs:
merge_kwargs["adapter_names"] = adapter_names
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
raise ValueError(
"The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
" to the latest version of PEFT. `pip install -U peft`"
)
module.merge(**merge_kwargs)
from peft.tuners.tuners_utils import BaseTunerLayer
merge_kwargs = {"safe_merge": self._safe_fusing}
if isinstance(module, BaseTunerLayer):
if self.lora_scale != 1.0:
module.scale_layer(self.lora_scale)
# For BC with prevous PEFT versions, we need to check the signature
# of the `merge` method to see if it supports the `adapter_names` argument.
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
if "adapter_names" in supported_merge_kwargs:
merge_kwargs["adapter_names"] = adapter_names
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
raise ValueError(
"The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
" to the latest version of PEFT. `pip install -U peft`"
)
module.merge(**merge_kwargs)
def unfuse_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `unfuse_lora()`.")
self.apply(self._unfuse_lora_apply)
def _unfuse_lora_apply(self, module):
if not USE_PEFT_BACKEND:
if hasattr(module, "_unfuse_lora"):
module._unfuse_lora()
else:
from peft.tuners.tuners_utils import BaseTunerLayer
if isinstance(module, BaseTunerLayer):
module.unmerge()
from peft.tuners.tuners_utils import BaseTunerLayer
if isinstance(module, BaseTunerLayer):
module.unmerge()
def unload_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `unload_lora()`.")
from ..utils import recurse_remove_peft_layers
recurse_remove_peft_layers(self)
if hasattr(self, "peft_config"):
del self.peft_config
def set_adapters(
self,
...
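The fuse/unfuse path is likewise PEFT-only now. A short usage sketch on a freshly created adapter (checkpoint ID and LoRA hyperparameters are illustrative):

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"]))

unet.fuse_lora(lora_scale=0.8, safe_fusing=True)  # merges the adapter into the base weights
unet.unfuse_lora()                                # restores the separate adapter weights
unet.unload_lora()                                # removes the injected peft layers entirely
```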
@@ -903,17 +903,6 @@ class UNet2DConditionModel(
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
def unload_lora(self):
"""Unloads LoRA weights."""
deprecate(
"unload_lora",
"0.28.0",
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
)
for module in self.modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
def get_time_embed(
self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
) -> Optional[torch.Tensor]:
...
@@ -22,7 +22,7 @@ import torch.utils.checkpoint
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import UNet2DConditionLoadersMixin
from ...utils import BaseOutput, deprecate, logging
from ...utils import BaseOutput, logging
from ..activations import get_activation
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
@@ -546,18 +546,6 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora
def unload_lora(self):
"""Unloads LoRA weights."""
deprecate(
"unload_lora",
"0.28.0",
"Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
)
for module in self.modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
def forward(
self,
sample: torch.Tensor,
...
@@ -37,7 +37,9 @@ from diffusers.utils.testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
is_peft_available,
load_hf_numpy,
require_peft_backend,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
require_torch_accelerator_with_training,
@@ -51,11 +53,38 @@ from diffusers.utils.testing_utils import (
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
if is_peft_available():
from peft import LoraConfig
from peft.tuners.tuners_utils import BaseTunerLayer
logger = logging.get_logger(__name__)
enable_full_determinism()
def get_unet_lora_config():
rank = 4
unet_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
return unet_lora_config
def check_if_lora_correctly_set(model) -> bool:
"""
Checks if the LoRA layers are correctly set with peft
"""
for module in model.modules():
if isinstance(module, BaseTunerLayer):
return True
return False
def create_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
@@ -1005,6 +1034,65 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)
@require_peft_backend
def test_lora(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without LoRA
with torch.no_grad():
non_lora_sample = model(**inputs_dict).sample
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# forward pass with LoRA
with torch.no_grad():
lora_sample = model(**inputs_dict).sample
assert not torch.allclose(
non_lora_sample, lora_sample, atol=1e-4, rtol=1e-4
), "LoRA injected UNet should produce different results."
@require_peft_backend
def test_lora_serialization(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
# forward pass without LoRA
with torch.no_grad():
non_lora_sample = model(**inputs_dict).sample
unet_lora_config = get_unet_lora_config()
model.add_adapter(unet_lora_config)
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
# forward pass with LoRA
with torch.no_grad():
lora_sample_1 = model(**inputs_dict).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
model.unload_lora()
model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
with torch.no_grad():
lora_sample_2 = model(**inputs_dict).sample
assert not torch.allclose(
non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4
), "LoRA injected UNet should produce different results."
assert torch.allclose(
lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
), "Loading from a saved checkpoint should produce identical results."
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
...