Unverified Commit 9c8eca70 authored by Sourab Mangrulkar, committed by GitHub

add lora delete feature (#5738)



* add lora delete feature

* added tests and changed condition

* deal with corner cases

* more corner cases

* rename to `delete_adapter_layers` for consistency

---------
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
parent 069123f6
@@ -36,6 +36,7 @@ from .utils import (
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
     convert_unet_state_dict_to_peft,
+    delete_adapter_layers,
     deprecate,
     get_adapter_name,
     get_peft_kwargs,
@@ -752,6 +753,27 @@ class UNet2DConditionLoadersMixin:
             raise ValueError("PEFT backend is required for this method.")
         set_adapter_layers(self, enabled=True)
+
+    def delete_adapters(self, adapter_names: Union[List[str], str]):
+        """
+        Deletes the LoRA layers of the given adapters for the unet.
+
+        Args:
+            adapter_names (`Union[List[str], str]`):
+                The names of the adapters to delete. Can be a single string or a list of strings.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        if isinstance(adapter_names, str):
+            adapter_names = [adapter_names]
+
+        for adapter_name in adapter_names:
+            delete_adapter_layers(self, adapter_name)
+
+            # Also pop the corresponding adapter from the config
+            if hasattr(self, "peft_config"):
+                self.peft_config.pop(adapter_name, None)
 def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
     cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
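To make the new unet-level API concrete, here is a minimal usage sketch; the LoRA file and adapter name are hypothetical, and it assumes diffusers is running with the PEFT backend enabled:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
# Hypothetical LoRA file; loads the adapter into the unet (and text encoder).
pipe.load_lora_weights("style-lora.safetensors", adapter_name="style")

# The new method: strip the "style" LoRA layers from the unet only.
# Accepts a single name or a list of names.
pipe.unet.delete_adapters("style")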
@@ -2507,6 +2529,30 @@ class LoraLoaderMixin:
         if hasattr(self, "text_encoder_2"):
             self.enable_lora_for_text_encoder(self.text_encoder_2)
+
+    def delete_adapters(self, adapter_names: Union[List[str], str]):
+        """
+        Deletes the LoRA layers of the given adapters for the unet and text-encoder(s).
+
+        Args:
+            adapter_names (`Union[List[str], str]`):
+                The names of the adapters to delete. Can be a single string or a list of strings.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        if isinstance(adapter_names, str):
+            adapter_names = [adapter_names]
+
+        # Delete unet adapters
+        self.unet.delete_adapters(adapter_names)
+
+        for adapter_name in adapter_names:
+            # Delete text encoder adapters
+            if hasattr(self, "text_encoder"):
+                delete_adapter_layers(self.text_encoder, adapter_name)
+            if hasattr(self, "text_encoder_2"):
+                delete_adapter_layers(self.text_encoder_2, adapter_name)
     def get_active_adapters(self) -> List[str]:
         """
         Gets the list of the current active adapters.
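The pipeline-level method fans out to the unet and any text encoders. A hedged end-to-end sketch, with hypothetical LoRA files and adapter names:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Hypothetical adapters.
pipe.load_lora_weights("pixel-lora.safetensors", adapter_name="pixel")
pipe.load_lora_weights("toy-lora.safetensors", adapter_name="toy")

pipe.set_adapters(["pixel", "toy"], adapter_weights=[1.0, 0.5])
image = pipe("a toy robot, pixel art", num_inference_steps=25).images[0]

# Delete "toy" everywhere (unet + text encoder(s)) and drop it from peft_config.
pipe.delete_adapters("toy")
print(pipe.get_active_adapters())  # expected: ["pixel"]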
@@ -89,6 +89,7 @@ from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (
     check_peft_version,
+    delete_adapter_layers,
     get_adapter_name,
     get_peft_kwargs,
     recurse_remove_peft_layers,
@@ -180,6 +180,28 @@ def set_adapter_layers(model, enabled=True):
                 module.disable_adapters = not enabled
+
+def delete_adapter_layers(model, adapter_name):
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            if hasattr(module, "delete_adapter"):
+                module.delete_adapter(adapter_name)
+            else:
+                raise ValueError(
+                    "The version of PEFT you are using is not compatible, please use a version that is greater than 0.6.1"
+                )
+
+    # For the transformers integration, we also need to pop the adapter from the config
+    if getattr(model, "_hf_peft_config_loaded", False) and hasattr(model, "peft_config"):
+        model.peft_config.pop(adapter_name, None)
+
+        # In case all adapters are deleted, delete the config
+        # and reset the flag
+        if len(model.peft_config) == 0:
+            del model.peft_config
+            model._hf_peft_config_loaded = None
 def set_weights_and_activate_adapters(model, adapter_names, weights):
     from peft.tuners.tuners_utils import BaseTunerLayer
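The helper itself is model-agnostic: it walks model.modules() and calls PEFT's delete_adapter on every BaseTunerLayer it finds, then cleans up peft_config for the transformers integration. Continuing the pipeline sketch above, it can also be called directly on a single component:

from diffusers.utils import delete_adapter_layers

# Strip the hypothetical "pixel" adapter from the text encoder only;
# the unet keeps its copy. If transformers attached a peft_config to the
# encoder, the corresponding entry is popped there too.
delete_adapter_layers(pipe.text_encoder, "pixel")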
@@ -831,6 +831,96 @@ class PeftLoraLoaderMixinTests:
                 "output with no lora and output with lora disabled should give same results",
             )
+    def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
+        """
+        Tests a simple inference with lora attached to the text encoder and unet, attaches
+        multiple adapters, and sets/deletes them.
+        """
+        for scheduler_cls in [DDIMScheduler, LCMScheduler]:
+            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(self.torch_device)
+            pipe.set_progress_bar_config(disable=None)
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+
+            self.assertTrue(
+                self.check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+            )
+            self.assertTrue(self.check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
+
+            if self.has_two_text_encoders:
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
+                self.assertTrue(
+                    self.check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+
+            pipe.set_adapters("adapter-1")
+            output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.set_adapters("adapter-2")
+            output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            pipe.set_adapters(["adapter-1", "adapter-2"])
+            output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertFalse(
+                np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+                "Adapter 1 and 2 should give different results",
+            )
+
+            self.assertFalse(
+                np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+                "Adapter 1 and mixed adapters should give different results",
+            )
+
+            self.assertFalse(
+                np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+                "Adapter 2 and mixed adapters should give different results",
+            )
+
+            pipe.delete_adapters("adapter-1")
+            output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertTrue(
+                np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
+                "Deleting adapter 1 should leave only adapter 2 active, matching its output",
+            )
+
+            pipe.delete_adapters("adapter-2")
+            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertTrue(
+                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
+                "output with no lora and output after deleting all adapters should give same results",
+            )
+
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+
+            pipe.unet.add_adapter(unet_lora_config, "adapter-1")
+            pipe.unet.add_adapter(unet_lora_config, "adapter-2")
+
+            pipe.set_adapters(["adapter-1", "adapter-2"])
+            pipe.delete_adapters(["adapter-1", "adapter-2"])
+
+            output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0)).images
+
+            self.assertTrue(
+                np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
+                "output with no lora and output after deleting all adapters should give same results",
+            )
     def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches