Unverified Commit bc7a4d49 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`PEFT`] Fix scale unscale with LoRA adapters (#5417)



* fix scale unscale v1

* final fixes + CI

* fix slow trst

* oops

* fix copies

* oops

* oops

* fix

* style

* fix copies

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 8dba1808
......@@ -405,7 +405,7 @@ class StableDiffusionLDM3DPipeline(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -374,7 +374,7 @@ class StableDiffusionModelEditingPipeline(DiffusionPipeline, TextualInversionLoa
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -358,7 +358,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -389,7 +389,7 @@ class StableDiffusionParadigmsPipeline(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -579,7 +579,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -381,7 +381,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, TextualInversionLoaderMixin)
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -372,7 +372,7 @@ class StableDiffusionUpscalePipeline(
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -479,7 +479,7 @@ class StableUnCLIPPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -433,7 +433,7 @@ class StableUnCLIPImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -436,12 +436,12 @@ class StableDiffusionXLPipeline(
if self.text_encoder is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2)
unscale_lora_layers(self.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......
......@@ -440,12 +440,12 @@ class StableDiffusionXLImg2ImgPipeline(
if self.text_encoder is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2)
unscale_lora_layers(self.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......
......@@ -590,12 +590,12 @@ class StableDiffusionXLInpaintPipeline(
if self.text_encoder is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2)
unscale_lora_layers(self.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......
......@@ -429,7 +429,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -450,12 +450,12 @@ class StableDiffusionXLAdapterPipeline(
if self.text_encoder is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
if self.text_encoder_2 is not None:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder_2)
unscale_lora_layers(self.text_encoder_2, lora_scale)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
......
......@@ -361,7 +361,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -423,7 +423,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -556,7 +556,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, negative_prompt_embeds
......
......@@ -1371,7 +1371,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self)
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (sample,)
......
......@@ -16,6 +16,7 @@ PEFT utilities: Utilities related to peft library
"""
import collections
import importlib
from typing import Optional
from packaging import version
......@@ -91,21 +92,28 @@ def scale_lora_layers(model, weight):
module.scale_layer(weight)
def unscale_lora_layers(model):
def unscale_lora_layers(model, weight: Optional[float] = None):
"""
Removes the previously passed weight given to the LoRA layers of the model.
Args:
model (`torch.nn.Module`):
The model to scale.
weight (`float`):
The weight to be given to the LoRA layers.
weight (`float`, *optional*):
The weight to be given to the LoRA layers. If no scale is passed the scale of the lora layer will be
re-initialized to the correct value. If 0.0 is passed, we will re-initialize the scale with the correct
value.
"""
from peft.tuners.tuners_utils import BaseTunerLayer
for module in model.modules():
if isinstance(module, BaseTunerLayer):
module.unscale_layer()
if weight is not None and weight != 0:
module.unscale_layer(weight)
elif weight is not None and weight == 0:
for adapter_name in module.active_adapters:
# if weight == 0 unscale should re-set the scale to the original value.
module.set_scale(adapter_name, 1.0)
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
......@@ -184,7 +192,7 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
module.set_adapter(adapter_name)
else:
module.active_adapter = adapter_name
module.scale_layer(weight)
module.set_scale(adapter_name, weight)
# set multiple active adapters
for module in model.modules():
......
......@@ -775,6 +775,79 @@ class PeftLoraLoaderMixinTests:
"output with no lora and output with lora disabled should give same results",
)
def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
self.assertTrue(self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.set_adapters("adapter-1")
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.set_adapters("adapter-2")
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.set_adapters(["adapter-1", "adapter-2"])
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
# Fuse and unfuse should lead to the same results
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 1 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 2 and mixed adapters should give different results",
)
pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertFalse(
np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Weighted adapter and mixed adapter should give different results",
)
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
def test_lora_fuse_nan(self):
components, _, text_lora_config, unet_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
......@@ -1073,7 +1146,6 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
expected_slice_scale = np.array([0.538, 0.539, 0.540, 0.540, 0.542, 0.539, 0.538, 0.541, 0.539])
predicted_slice = images[0, -3:, -3:, -1].flatten()
# import pdb; pdb.set_trace()
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
......@@ -1106,7 +1178,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
output_type="np",
).images
predicted_slice = images[0, -3:, -3:, -1].flatten()
expected_slice_scale = np.array([0.5977, 0.5985, 0.6039, 0.5976, 0.6025, 0.6036, 0.5946, 0.5979, 0.5998])
expected_slice_scale = np.array([0.5888, 0.5897, 0.5946, 0.5888, 0.5935, 0.5946, 0.5857, 0.5891, 0.5909])
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
# Lora disabled
......@@ -1120,7 +1192,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
output_type="np",
).images
predicted_slice = images[0, -3:, -3:, -1].flatten()
expected_slice_scale = np.array([0.54625, 0.5473, 0.5495, 0.5465, 0.5476, 0.5461, 0.5452, 0.5485, 0.5493])
expected_slice_scale = np.array([0.5456, 0.5466, 0.5487, 0.5458, 0.5469, 0.5454, 0.5446, 0.5479, 0.5487])
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment