Unverified commit 05e7a854 authored by Sayak Paul, committed by GitHub

[lora] fix: lora unloading behaviour (#11822)

* fix: lora unloading behaviour

* fix

* update
parent 76ec3d1f
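
The root cause: unloading LoRA weights removes the PEFT layers via `recurse_remove_peft_layers` and deletes `peft_config`, but previously left the internal `_hf_peft_config_loaded` flag set, so adding a fresh adapter to the same model after unloading could misbehave. The first hunk below resets that flag during unloading. A minimal sketch of the user-facing flow this enables (the checkpoint id and LoRA hyperparameters here are illustrative, not taken from this PR):

    import torch
    from peft import LoraConfig
    from diffusers import StableDiffusionPipeline

    # Any LoRA-capable pipeline should behave the same way; this checkpoint is just an example.
    pipe = StableDiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Hypothetical adapter config targeting the UNet attention projections.
    lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v"])

    pipe.unet.add_adapter(lora_config)  # first add_adapter() call works
    pipe.unload_lora_weights()          # strips PEFT layers; now also clears _hf_peft_config_loaded
    pipe.unet.add_adapter(lora_config)  # with this fix, re-adding after unload works again
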
@@ -693,6 +693,8 @@ class PeftAdapterMixin:
         recurse_remove_peft_layers(self)
         if hasattr(self, "peft_config"):
             del self.peft_config
+        if hasattr(self, "_hf_peft_config_loaded"):
+            self._hf_peft_config_loaded = None
         _maybe_remove_and_reapply_group_offloading(self)

@@ -291,9 +291,7 @@ class PeftLoraLoaderMixinTests:
         return modules_to_save

-    def check_if_adapters_added_correctly(
-        self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
-    ):
+    def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
@@ -345,7 +343,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -428,7 +426,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -484,7 +482,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -522,7 +520,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         pipe.fuse_lora()
         # Fusing should still keep the LoRA layers
@@ -554,7 +552,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
@@ -589,7 +587,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -640,7 +638,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         state_dict = {}
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -691,7 +689,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
@@ -734,7 +732,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -775,7 +773,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
@@ -819,7 +817,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
@@ -857,7 +855,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         pipe.unload_lora_weights()
         # unloading should remove the LoRA layers
@@ -893,7 +891,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
         self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
@@ -1010,7 +1008,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
         )
@@ -1032,7 +1030,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
         )
@@ -1759,7 +1757,7 @@ class PeftLoraLoaderMixinTests:
         output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_dora_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1850,7 +1848,7 @@ class PeftLoraLoaderMixinTests:
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
         pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
@@ -1937,7 +1935,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         lora_scale = 0.5
         attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
@@ -2119,7 +2117,7 @@ class PeftLoraLoaderMixinTests:
         pipe = pipe.to(torch_device, dtype=compute_dtype)
         pipe.set_progress_bar_config(disable=None)

-        pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
+        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         if storage_dtype is not None:
             denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
@@ -2237,7 +2235,7 @@ class PeftLoraLoaderMixinTests:
         )
         pipe = self.pipeline_class(**components)

-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )
@@ -2290,7 +2288,7 @@ class PeftLoraLoaderMixinTests:
         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

-        pipe, _ = self.check_if_adapters_added_correctly(
+        pipe, _ = self.add_adapters_to_pipeline(
             pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
         )
         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -2309,6 +2307,25 @@ class PeftLoraLoaderMixinTests:
             np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
         )

+    def test_lora_unload_add_adapter(self):
+        """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        # unload and then add.
+        pipe.unload_lora_weights()
+        pipe, _ = self.add_adapters_to_pipeline(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes: