Unverified Commit 02ba50c6 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`PEFT` / `LoRA`] Fix civitai bug when network alpha is an empty dict (#5608)



* fix civitai bug

* add test

* up

* fix test

* added slow test.

* style

* Update src/diffusers/utils/peft_utils.py
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/utils/peft_utils.py

---------
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>
parent 4f2bf673
...@@ -129,7 +129,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True ...@@ -129,7 +129,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items())) rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()} rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
if network_alpha_dict is not None: if network_alpha_dict is not None and len(network_alpha_dict) > 0:
if len(set(network_alpha_dict.values())) > 1: if len(set(network_alpha_dict.values())) > 1:
# get the alpha occuring the most number of times # get the alpha occuring the most number of times
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0] lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
......
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from huggingface_hub.repocard import RepoCard from huggingface_hub.repocard import RepoCard
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
...@@ -1772,6 +1773,28 @@ class LoraSDXLIntegrationTests(unittest.TestCase): ...@@ -1772,6 +1773,28 @@ class LoraSDXLIntegrationTests(unittest.TestCase):
self.assertTrue(np.allclose(images, expected, atol=1e-3)) self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipe) release_memory(pipe)
def test_sd_load_civitai_empty_network_alpha(self):
"""
This test simply checks that loading a LoRA with an empty network alpha works fine
See: https://github.com/huggingface/diffusers/issues/5606
"""
pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
pipeline.enable_sequential_cpu_offload()
civitai_path = hf_hub_download("ybelkada/test-ahi-civitai", "ahi_lora_weights.safetensors")
pipeline.load_lora_weights(civitai_path, adapter_name="ahri")
images = pipeline(
"ahri, masterpiece, league of legends",
output_type="np",
generator=torch.manual_seed(156),
num_inference_steps=5,
).images
images = images[0, -3:, -3:, -1].flatten()
expected = np.array([0.0, 0.0, 0.0, 0.002557, 0.020954, 0.001792, 0.006581, 0.00591, 0.002995])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
release_memory(pipeline)
def test_canny_lora(self): def test_canny_lora(self):
controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0") controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment