Unverified Commit b09a2aa3 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] fix `cross_attention_kwargs` problems and tighten tests (#7388)

* debugging

* let's see the numbers

* let's see the numbers

* let's see the numbers

* restrict tolerance.

* increase inference steps.

* shallow copy of cross_attentionkwargs

* remove print
parent 63b68468
...@@ -1178,6 +1178,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, ...@@ -1178,6 +1178,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
# we're popping the `scale` instead of getting it because otherwise `scale` will be propagated # we're popping the `scale` instead of getting it because otherwise `scale` will be propagated
# to the internal blocks and will raise deprecation warnings. this will be confusing for our users. # to the internal blocks and will raise deprecation warnings. this will be confusing for our users.
if cross_attention_kwargs is not None: if cross_attention_kwargs is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
lora_scale = cross_attention_kwargs.pop("scale", 1.0) lora_scale = cross_attention_kwargs.pop("scale", 1.0)
else: else:
lora_scale = 1.0 lora_scale = 1.0
......
...@@ -158,7 +158,7 @@ class PeftLoraLoaderMixinTests: ...@@ -158,7 +158,7 @@ class PeftLoraLoaderMixinTests:
pipeline_inputs = { pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger", "prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 2, "num_inference_steps": 5,
"guidance_scale": 6.0, "guidance_scale": 6.0,
"output_type": "np", "output_type": "np",
} }
...@@ -589,7 +589,7 @@ class PeftLoraLoaderMixinTests: ...@@ -589,7 +589,7 @@ class PeftLoraLoaderMixinTests:
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5} **inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
).images ).images
self.assertTrue( self.assertTrue(
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3), not np.allclose(output_lora, output_lora_scale, atol=1e-4, rtol=1e-4),
"Lora + scale should change the output", "Lora + scale should change the output",
) )
...@@ -1300,6 +1300,11 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase): ...@@ -1300,6 +1300,11 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipe.load_lora_weights(lora_id) pipe.load_lora_weights(lora_id)
pipe = pipe.to("cuda") pipe = pipe.to("cuda")
self.assertTrue(
self.check_if_lora_correctly_set(pipe.unet),
"Lora not correctly set in UNet",
)
self.assertTrue( self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder), self.check_if_lora_correctly_set(pipe.text_encoder),
"Lora not correctly set in text encoder 2", "Lora not correctly set in text encoder 2",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment