Unverified commit 692b7a90, authored by Sayak Paul, committed by GitHub

[Feat] add: utility for unloading lora. (#4034)

* add: test for testing unloading lora.

* add: reason to skipif.

* initial implementation of lora unload().

* apply styling.

* add: doc.

* change checkpoints.

* reinit generator

* finalize slow test.

* add fast test for unloading lora.
parent 71c918b8
@@ -280,6 +280,10 @@ Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is
**Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs,
refer to the respective docstrings.
## Unloading LoRA parameters
You can call [`~diffusers.loaders.LoraLoaderMixin.unload_lora_weights`] on a pipeline to unload the LoRA parameters.
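A minimal usage sketch (the base checkpoint matches the one used in this PR's tests; the LoRA path and prompt are placeholders, not part of the diff):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Load LoRA parameters (placeholder local path shown for illustration).
pipe.load_lora_weights("path/to/lora")
lora_image = pipe("masterpiece, best quality, mountain", generator=torch.manual_seed(0)).images[0]

# Unload them again; subsequent calls behave like the original pipeline.
pipe.unload_lora_weights()
base_image = pipe("masterpiece, best quality, mountain", generator=torch.manual_seed(0)).images[0]
```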
## Supporting A1111 themed LoRA checkpoints from Diffusers

To provide seamless interoperability with A1111 to our users, we support loading A1111 formatted
...
@@ -25,6 +25,8 @@ from torch import nn
from .models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    AttnProcessor,
    AttnProcessor2_0,
    CustomDiffusionAttnProcessor,
    CustomDiffusionXFormersAttnProcessor,
    LoRAAttnAddedKVProcessor,
@@ -1270,6 +1272,38 @@ class LoraLoaderMixin:
        new_state_dict = {**unet_state_dict, **te_state_dict}
        return new_state_dict, network_alpha
    def unload_lora_weights(self):
        """
        Unloads the LoRA parameters.

        Examples:

        ```python
        >>> # Assuming `pipeline` is already loaded with the LoRA parameters.
        >>> pipeline.unload_lora_weights()
        >>> ...
        ```
        """
        is_unet_lora = all(
            isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor, LoRAAttnAddedKVProcessor))
            for _, processor in self.unet.attn_processors.items()
        )

        # Handle attention processors that are a mix of regular attention and AddedKV
        # attention.
        if is_unet_lora:
            is_attn_procs_mixed = all(
                isinstance(processor, (LoRAAttnProcessor2_0, LoRAAttnProcessor))
                for _, processor in self.unet.attn_processors.items()
            )
            if not is_attn_procs_mixed:
                unet_attn_proc_cls = AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
                self.unet.set_attn_processor(unet_attn_proc_cls())
            else:
                self.unet.set_default_attn_processor()

        # Safe to call the following regardless of LoRA.
        self._remove_text_encoder_monkey_patch()
class FromSingleFileMixin:
    """
...
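To illustrate what the new method does, here is a small sketch (assuming `pipe` is a `StableDiffusionPipeline` with LoRA weights already loaded; the processor classes come from `diffusers.models.attention_processor`, as in the import diff above) that inspects the UNet's attention processors before and after unloading:

```python
from diffusers.models.attention_processor import (
    AttnProcessor,
    AttnProcessor2_0,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
)

# With LoRA loaded, the UNet's processors are LoRA variants.
assert all(
    isinstance(proc, (LoRAAttnProcessor, LoRAAttnProcessor2_0))
    for proc in pipe.unet.attn_processors.values()
)

pipe.unload_lora_weights()

# After unloading, they are reset to the regular implementations
# (either the plain or the PyTorch 2.0 variant, depending on the branch taken).
assert all(
    isinstance(proc, (AttnProcessor, AttnProcessor2_0))
    for proc in pipe.unet.attn_processors.values()
)
```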
@@ -83,9 +83,9 @@ def create_text_encoder_lora_layers(text_encoder: nn.Module):
    return text_encoder_lora_layers

-def set_lora_weights(text_lora_attn_parameters, randn_weight=False):
+def set_lora_weights(lora_attn_parameters, randn_weight=False):
    with torch.no_grad():
-        for parameter in text_lora_attn_parameters:
+        for parameter in lora_attn_parameters:
            if randn_weight:
                parameter[:] = torch.randn_like(parameter)
            else:
@@ -155,7 +155,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
        }
        return pipeline_components, lora_components

-    def get_dummy_inputs(self):
+    def get_dummy_inputs(self, with_generator=True):
        batch_size = 1
        sequence_length = 10
        num_channels = 4
@@ -167,16 +167,16 @@ class LoraLoaderMixinTests(unittest.TestCase):
        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
-            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
-            "output_type": "numpy",
+            "output_type": "np",
        }
+        if with_generator:
+            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs
    # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb
    def get_dummy_tokens(self):
        max_seq_length = 77
@@ -399,6 +399,45 @@ class LoraLoaderMixinTests(unittest.TestCase):
            )
            self.assertIsInstance(module.processor, attn_proc_class)
    def test_unload_lora(self):
        pipeline_components, lora_components = self.get_dummy_components()
        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
        sd_pipe = StableDiffusionPipeline(**pipeline_components)

        original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
        orig_image_slice = original_images[0, -3:, -3:, -1]

        # Emulate training.
        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
        set_lora_weights(lora_components["text_encoder_lora_layers"].parameters(), randn_weight=True)

        with tempfile.TemporaryDirectory() as tmpdirname:
            LoraLoaderMixin.save_lora_weights(
                save_directory=tmpdirname,
                unet_lora_layers=lora_components["unet_lora_layers"],
                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            sd_pipe.load_lora_weights(tmpdirname)

        lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
        lora_image_slice = lora_images[0, -3:, -3:, -1]

        # Unload LoRA parameters.
        sd_pipe.unload_lora_weights()
        original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
        orig_image_slice_two = original_images_two[0, -3:, -3:, -1]

        assert not np.allclose(
            orig_image_slice, lora_image_slice
        ), "LoRA parameters should lead to a different image slice."
        assert not np.allclose(
            orig_image_slice_two, lora_image_slice
        ), "LoRA parameters should lead to a different image slice."
        assert np.allclose(
            orig_image_slice, orig_image_slice_two, atol=1e-3
        ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
    @unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
    def test_lora_unet_attn_processors_with_xformers(self):
        with tempfile.TemporaryDirectory() as tmpdirname:
@@ -537,3 +576,35 @@ class LoraIntegrationTests(unittest.TestCase):
        expected = np.array([0.7406, 0.699, 0.5963, 0.7493, 0.7045, 0.6096, 0.6886, 0.6388, 0.583])
        self.assertTrue(np.allclose(images, expected, atol=1e-4))
    def test_unload_lora(self):
        generator = torch.manual_seed(0)
        prompt = "masterpiece, best quality, mountain"
        num_inference_steps = 2

        pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None).to(
            torch_device
        )
        initial_images = pipe(
            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
        ).images
        initial_images = initial_images[0, -3:, -3:, -1].flatten()

        lora_model_id = "hf-internal-testing/civitai-colored-icons-lora"
        lora_filename = "Colored_Icons_by_vizsumit.safetensors"
        pipe.load_lora_weights(lora_model_id, weight_name=lora_filename)

        # Reinitialize the generator so each run starts from the same seed.
        generator = torch.manual_seed(0)
        lora_images = pipe(
            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
        ).images
        lora_images = lora_images[0, -3:, -3:, -1].flatten()

        pipe.unload_lora_weights()
        generator = torch.manual_seed(0)
        unloaded_lora_images = pipe(
            prompt, output_type="np", generator=generator, num_inference_steps=num_inference_steps
        ).images
        unloaded_lora_images = unloaded_lora_images[0, -3:, -3:, -1].flatten()

        self.assertFalse(np.allclose(initial_images, lora_images))
        self.assertTrue(np.allclose(initial_images, unloaded_lora_images, atol=1e-3))