Unverified Commit e0d8c910 authored by Kashif Rasul, committed by GitHub

[Peft] fix saving / loading when unet is not "unet" (#6046)



* [Peft] fix saving / loading when unet is not "unet"

* Update src/diffusers/loaders/lora.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* undo stablediffusion-xl changes

* use unet_name to get unet for lora helpers

* use unet_name

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent a3d31e3a
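
The whole patch revolves around one lookup pattern: instead of hard-coding `self.unet`, each helper resolves the model through the class-level `unet_name` attribute and only uses `self.unet` when that attribute actually exists. A minimal, self-contained sketch of the idea (the `FakeUNet`, `CustomPipeline`, and `prior_unet` names below are illustrative and not part of the commit):

```python
class FakeUNet:
    """Stand-in for a UNet-like model; only here to illustrate the attribute lookup."""

    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
        print(f"fusing LoRA with scale {lora_scale}")


class CustomPipeline:
    # Class-level name of the attribute that holds the UNet-like model.
    unet_name = "prior_unet"

    def __init__(self):
        # The model is *not* stored as `self.unet`, which is what broke the old helpers.
        self.prior_unet = FakeUNet()

    def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
        # The resolution pattern used throughout the commit: fall back to `unet_name`
        # when the pipeline has no `unet` attribute.
        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
        unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)


CustomPipeline().fuse_lora(lora_scale=0.8)  # prints: fusing LoRA with scale 0.8
```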
@@ -149,9 +149,11 @@ class IPAdapterMixin:
             self.feature_extractor = CLIPImageProcessor()
 
         # load ip-adapter into unet
-        self.unet._load_ip_adapter_weights(state_dict)
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet._load_ip_adapter_weights(state_dict)
 
     def set_ip_adapter_scale(self, scale):
-        for attn_processor in self.unet.attn_processors.values():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for attn_processor in unet.attn_processors.values():
             if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
                 attn_processor.scale = scale
@@ -912,10 +912,10 @@ class LoraLoaderMixin:
             )
 
         if unet_lora_layers:
-            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+            state_dict.update(pack_weights(unet_lora_layers, cls.unet_name))
 
         if text_encoder_lora_layers:
-            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
 
         if transformer_lora_layers:
             state_dict.update(pack_weights(transformer_lora_layers, "transformer"))
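
`save_lora_weights` now namespaces the serialized keys with `cls.unet_name` and `cls.text_encoder_name` rather than the literal strings `"unet"` and `"text_encoder"`. The body of `pack_weights` is not part of this diff, so the helper below is only a sketch of the assumed behaviour, namely prefixing every state-dict key with the component name:

```python
import torch


def pack_weights(layers, prefix):
    # Assumed behaviour: flatten a module (or a plain dict) into a state dict whose
    # keys are namespaced with the component name, e.g. "unet.<param>" or "prior_unet.<param>".
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{name}": param for name, param in layers_weights.items()}


unet_lora_layers = {"down_blocks.0.lora_A.weight": torch.zeros(4, 4)}
print(list(pack_weights(unet_lora_layers, "prior_unet")))
# ['prior_unet.down_blocks.0.lora_A.weight']
```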
@@ -975,6 +975,8 @@ class LoraLoaderMixin:
         >>> ...
         ```
         """
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+
         if not USE_PEFT_BACKEND:
             if version.parse(__version__) > version.parse("0.23"):
                 logger.warn(
@@ -982,13 +984,13 @@ class LoraLoaderMixin:
                     "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
                 )
 
-            for _, module in self.unet.named_modules():
+            for _, module in unet.named_modules():
                 if hasattr(module, "set_lora_layer"):
                     module.set_lora_layer(None)
         else:
-            recurse_remove_peft_layers(self.unet)
-            if hasattr(self.unet, "peft_config"):
-                del self.unet.peft_config
+            recurse_remove_peft_layers(unet)
+            if hasattr(unet, "peft_config"):
+                del unet.peft_config
 
         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
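
With `unet` resolved once at the top of `unload_lora_weights`, both branches (clearing the `lora_layer` on each module without the PEFT backend, and `recurse_remove_peft_layers` with it) act on the correct model. A hedged usage sketch; the repository ids are placeholders:

```python
from diffusers import DiffusionPipeline

# Placeholder ids; substitute a real base checkpoint and a LoRA repository.
pipe = DiffusionPipeline.from_pretrained("some-org/some-base-model")
pipe.load_lora_weights("some-org/some-lora")

# ... run inference with the LoRA applied ...

# Put the pipeline back into its original, LoRA-free state. The resolved `unet`
# is used internally, so this also works when the attribute is not named "unet".
pipe.unload_lora_weights()
```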
@@ -1027,7 +1029,8 @@ class LoraLoaderMixin:
                 )
 
         if fuse_unet:
-            self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
+            unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+            unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
 
         if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer
@@ -1080,13 +1083,14 @@ class LoraLoaderMixin:
                 Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
                 LoRA parameters then it won't have any effect.
         """
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         if unfuse_unet:
             if not USE_PEFT_BACKEND:
-                self.unet.unfuse_lora()
+                unet.unfuse_lora()
             else:
                 from peft.tuners.tuners_utils import BaseTunerLayer
 
-                for module in self.unet.modules():
+                for module in unet.modules():
                     if isinstance(module, BaseTunerLayer):
                         module.unmerge()
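
`fuse_lora` and `unfuse_lora` use the same lookup, so fusing LoRA weights into the base model also works when the UNet is stored under a different attribute. Continuing the usage sketch above:

```python
# Fuse the loaded LoRA into the base weights to avoid the extra LoRA forward pass,
# then undo the fusion to restore the original weights.
pipe.fuse_lora(lora_scale=0.7)
# ... inference with fused weights ...
pipe.unfuse_lora()
```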
@@ -1202,8 +1206,9 @@ class LoraLoaderMixin:
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[List[float]] = None,
     ):
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
         # Handle the UNET
-        self.unet.set_adapters(adapter_names, adapter_weights)
+        unet.set_adapters(adapter_names, adapter_weights)
 
         # Handle the Text Encoder
         if hasattr(self, "text_encoder"):
@@ -1216,7 +1221,8 @@ class LoraLoaderMixin:
             raise ValueError("PEFT backend is required for this method.")
 
         # Disable unet adapters
-        self.unet.disable_lora()
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.disable_lora()
 
         # Disable text encoder adapters
         if hasattr(self, "text_encoder"):
@@ -1229,7 +1235,8 @@ class LoraLoaderMixin:
             raise ValueError("PEFT backend is required for this method.")
 
         # Enable unet adapters
-        self.unet.enable_lora()
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.enable_lora()
 
         # Enable text encoder adapters
         if hasattr(self, "text_encoder"):
@@ -1251,7 +1258,8 @@ class LoraLoaderMixin:
             adapter_names = [adapter_names]
 
         # Delete unet adapters
-        self.unet.delete_adapters(adapter_names)
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        unet.delete_adapters(adapter_names)
 
         for adapter_name in adapter_names:
             # Delete text encoder adapters
@@ -1284,8 +1292,8 @@ class LoraLoaderMixin:
         from peft.tuners.tuners_utils import BaseTunerLayer
 
         active_adapters = []
-
-        for module in self.unet.modules():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for module in unet.modules():
             if isinstance(module, BaseTunerLayer):
                 active_adapters = module.active_adapters
                 break
@@ -1309,8 +1317,9 @@ class LoraLoaderMixin:
         if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
             set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
 
-        if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
-            set_adapters["unet"] = list(self.unet.peft_config.keys())
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        if hasattr(self, self.unet_name) and hasattr(unet, "peft_config"):
+            set_adapters[self.unet_name] = list(self.unet.peft_config.keys())
 
         return set_adapters
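
The adapter listing is keyed by component name, so for a pipeline whose UNet attribute is not called `unet` the key now follows `unet_name`. The method name below is inferred from the surrounding context and should be read as an assumption:

```python
# Assuming the enclosing method is the pipeline's adapter-listing helper
# (e.g. `pipe.get_list_adapters()` in recent diffusers versions), the result maps
# component names to adapter names, for example:
#   {"unet": ["style", "detail"], "text_encoder": ["style"]}
# For a custom pipeline with `unet_name = "prior_unet"`, the key becomes "prior_unet".
print(pipe.get_list_adapters())
```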
@@ -1331,7 +1340,8 @@ class LoraLoaderMixin:
         from peft.tuners.tuners_utils import BaseTunerLayer
 
         # Handle the UNET
-        for unet_module in self.unet.modules():
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for unet_module in unet.modules():
             if isinstance(unet_module, BaseTunerLayer):
                 for adapter_name in adapter_names:
                     unet_module.lora_A[adapter_name].to(device)