Unverified Commit b47f5115 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Lora] fix lora fuse unfuse (#5003)



* fix lora fuse unfuse

* add same changes to loaders.py

* add test

---------
Co-authored-by: default avatarmultimodalart <joaopaulo.passos+multimodal@gmail.com>
parent 324aef6d
......@@ -121,7 +121,7 @@ class PatchedLoraProjection(nn.Module):
self.lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.regular_linear_layer.weight.data
......
......@@ -139,7 +139,7 @@ class LoRACompatibleConv(nn.Conv2d):
self._lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.weight.data
......@@ -204,7 +204,7 @@ class LoRACompatibleLinear(nn.Linear):
self._lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.weight.data
......
......@@ -43,7 +43,7 @@ from diffusers.models.attention_processor import (
LoRAAttnProcessor2_0,
XFormersAttnProcessor,
)
from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu, slow, torch_device
from diffusers.utils.testing_utils import floats_tensor, nightly, require_torch_gpu, slow, torch_device
def create_unet_lora_layers(unet: nn.Module):
......@@ -1497,3 +1497,41 @@ class LoraIntegrationTests(unittest.TestCase):
expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
@nightly
def test_sequential_fuse_unfuse(self):
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
# 1. round
pipe.load_lora_weights("Pclanglais/TintinIA")
pipe.fuse_lora()
generator = torch.Generator().manual_seed(0)
images = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
image_slice = images[0, -3:, -3:, -1].flatten()
pipe.unfuse_lora()
# 2. round
pipe.load_lora_weights("ProomptEngineer/pe-balloon-diffusion-style")
pipe.fuse_lora()
pipe.unfuse_lora()
# 3. round
pipe.load_lora_weights("ostris/crayon_style_lora_sdxl")
pipe.fuse_lora()
pipe.unfuse_lora()
# 4. back to 1st round
pipe.load_lora_weights("Pclanglais/TintinIA")
pipe.fuse_lora()
generator = torch.Generator().manual_seed(0)
images_2 = pipe(
"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
).images
image_slice_2 = images_2[0, -3:, -3:, -1].flatten()
self.assertTrue(np.allclose(image_slice, image_slice_2, atol=1e-3))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment