Unverified Commit ea6938ae authored by Donald.Lee's avatar Donald.Lee Committed by GitHub
Browse files

Fix: unet save_attn_procs at UNet2DconditionLoadersMixin (#8699)



* fix: unet save_attn_procs at custom diffusion

* style: recover unchanaged parts(max line length 119) / mod: add condition

* style: recover unchanaged parts(max line length 119)

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 8ef0d9de
...@@ -457,6 +457,15 @@ class UNet2DConditionLoadersMixin: ...@@ -457,6 +457,15 @@ class UNet2DConditionLoadersMixin:
) )
if is_custom_diffusion: if is_custom_diffusion:
state_dict = self._get_custom_diffusion_state_dict() state_dict = self._get_custom_diffusion_state_dict()
if save_function is None and safe_serialization:
# safetensors does not support saving dicts with non-tensor values
empty_state_dict = {k: v for k, v in state_dict.items() if not isinstance(v, torch.Tensor)}
if len(empty_state_dict) > 0:
logger.warning(
f"Safetensors does not support saving dicts with non-tensor values. "
f"The following keys will be ignored: {empty_state_dict.keys()}"
)
state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
else: else:
if not USE_PEFT_BACKEND: if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.") raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment