Unverified Commit ea5a6a8b authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Tests] Cleanup lora tests utils (#11276)

* start cleaning up lora test utils for reusability

* update

* updates

* updates
parent b8093e66
...@@ -260,6 +260,31 @@ class PeftLoraLoaderMixinTests: ...@@ -260,6 +260,31 @@ class PeftLoraLoaderMixinTests:
return modules_to_save return modules_to_save
def check_if_adapters_added_correctly(
self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"
):
if text_lora_config is not None:
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, adapter_name=adapter_name)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
if denoiser_lora_config is not None:
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, adapter_name=adapter_name)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
else:
denoiser = None
if text_lora_config is not None and self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config, adapter_name=adapter_name)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
return pipe, denoiser
def test_simple_inference(self): def test_simple_inference(self):
""" """
Tests a simple inference and makes sure it works as expected Tests a simple inference and makes sure it works as expected
...@@ -289,16 +314,7 @@ class PeftLoraLoaderMixinTests: ...@@ -289,16 +314,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config) pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue( self.assertTrue(
...@@ -381,22 +397,7 @@ class PeftLoraLoaderMixinTests: ...@@ -381,22 +397,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
...@@ -459,16 +460,7 @@ class PeftLoraLoaderMixinTests: ...@@ -459,16 +460,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config) pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue( self.assertTrue(
...@@ -506,15 +498,7 @@ class PeftLoraLoaderMixinTests: ...@@ -506,15 +498,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config) pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.fuse_lora() pipe.fuse_lora()
# Fusing should still keep the LoRA layers # Fusing should still keep the LoRA layers
...@@ -546,19 +530,7 @@ class PeftLoraLoaderMixinTests: ...@@ -546,19 +530,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.unload_lora_weights() pipe.unload_lora_weights()
# unloading should remove the LoRA layers # unloading should remove the LoRA layers
...@@ -593,18 +565,7 @@ class PeftLoraLoaderMixinTests: ...@@ -593,18 +565,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
...@@ -655,22 +616,20 @@ class PeftLoraLoaderMixinTests: ...@@ -655,22 +616,20 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config) pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
# Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder` state_dict = {}
# supports missing layers (PR#8324). if "text_encoder" in self.pipeline_class._lora_loadable_modules:
state_dict = { # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
f"text_encoder.{module_name}": param # supports missing layers (PR#8324).
for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items() state_dict = {
if "text_model.encoder.layers.4" not in module_name f"text_encoder.{module_name}": param
} for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
if "text_model.encoder.layers.4" not in module_name
}
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
state_dict.update( state_dict.update(
{ {
f"text_encoder_2.{module_name}": param f"text_encoder_2.{module_name}": param
...@@ -694,7 +653,7 @@ class PeftLoraLoaderMixinTests: ...@@ -694,7 +653,7 @@ class PeftLoraLoaderMixinTests:
"Removing adapters should change the output", "Removing adapters should change the output",
) )
def test_simple_inference_save_pretrained(self): def test_simple_inference_save_pretrained_with_text_lora(self):
""" """
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
""" """
...@@ -708,16 +667,7 @@ class PeftLoraLoaderMixinTests: ...@@ -708,16 +667,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config) pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config=None)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
...@@ -726,10 +676,11 @@ class PeftLoraLoaderMixinTests: ...@@ -726,10 +676,11 @@ class PeftLoraLoaderMixinTests:
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname) pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
pipe_from_pretrained.to(torch_device) pipe_from_pretrained.to(torch_device)
self.assertTrue( if "text_encoder" in self.pipeline_class._lora_loadable_modules:
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder), self.assertTrue(
"Lora not correctly set in text encoder", check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
) "Lora not correctly set in text encoder",
)
if self.has_two_text_encoders or self.has_three_text_encoders: if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules: if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
...@@ -759,22 +710,7 @@ class PeftLoraLoaderMixinTests: ...@@ -759,22 +710,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
...@@ -820,22 +756,7 @@ class PeftLoraLoaderMixinTests: ...@@ -820,22 +756,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue( self.assertTrue(
...@@ -879,22 +800,7 @@ class PeftLoraLoaderMixinTests: ...@@ -879,22 +800,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
...@@ -932,22 +838,7 @@ class PeftLoraLoaderMixinTests: ...@@ -932,22 +838,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.unload_lora_weights() pipe.unload_lora_weights()
# unloading should remove the LoRA layers # unloading should remove the LoRA layers
...@@ -983,22 +874,7 @@ class PeftLoraLoaderMixinTests: ...@@ -983,22 +874,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules) pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
...@@ -1104,6 +980,8 @@ class PeftLoraLoaderMixinTests: ...@@ -1104,6 +980,8 @@ class PeftLoraLoaderMixinTests:
) )
def test_wrong_adapter_name_raises_error(self): def test_wrong_adapter_name_raises_error(self):
adapter_name = "adapter-1"
scheduler_cls = self.scheduler_classes[0] scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
...@@ -1111,20 +989,9 @@ class PeftLoraLoaderMixinTests: ...@@ -1111,20 +989,9 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe, _ = self.check_if_adapters_added_correctly(
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") )
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
with self.assertRaises(ValueError) as err_context: with self.assertRaises(ValueError) as err_context:
pipe.set_adapters("test") pipe.set_adapters("test")
...@@ -1132,10 +999,11 @@ class PeftLoraLoaderMixinTests: ...@@ -1132,10 +999,11 @@ class PeftLoraLoaderMixinTests:
self.assertTrue("not in the list of present adapters" in str(err_context.exception)) self.assertTrue("not in the list of present adapters" in str(err_context.exception))
# test this works. # test this works.
pipe.set_adapters("adapter-1") pipe.set_adapters(adapter_name)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0] _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
def test_multiple_wrong_adapter_name_raises_error(self): def test_multiple_wrong_adapter_name_raises_error(self):
adapter_name = "adapter-1"
scheduler_cls = self.scheduler_classes[0] scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
...@@ -1143,33 +1011,22 @@ class PeftLoraLoaderMixinTests: ...@@ -1143,33 +1011,22 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe, _ = self.check_if_adapters_added_correctly(
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe, text_lora_config, denoiser_lora_config, adapter_name=adapter_name
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") )
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0} scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
logger = logging.get_logger("diffusers.loaders.lora_base") logger = logging.get_logger("diffusers.loaders.lora_base")
logger.setLevel(30) logger.setLevel(30)
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
pipe.set_adapters("adapter-1", adapter_weights=scale_with_wrong_components) pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)
wrong_components = sorted(set(scale_with_wrong_components.keys())) wrong_components = sorted(set(scale_with_wrong_components.keys()))
msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. " msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. "
self.assertTrue(msg in str(cap_logger.out)) self.assertTrue(msg in str(cap_logger.out))
# test this works. # test this works.
pipe.set_adapters("adapter-1") pipe.set_adapters(adapter_name)
_ = pipe(**inputs, generator=torch.manual_seed(0))[0] _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
def test_simple_inference_with_text_denoiser_block_scale(self): def test_simple_inference_with_text_denoiser_block_scale(self):
...@@ -1804,20 +1661,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1804,20 +1661,7 @@ class PeftLoraLoaderMixinTests:
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_dora_lora.shape == self.output_shape) self.assertTrue(output_no_dora_lora.shape == self.output_shape)
pipe.text_encoder.add_adapter(text_lora_config) pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
...@@ -1908,18 +1752,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1908,18 +1752,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config) pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True) pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
...@@ -2011,22 +1844,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2011,22 +1844,7 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe, _ = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
lora_scale = 0.5 lora_scale = 0.5
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}} attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
...@@ -2211,22 +2029,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2211,22 +2029,7 @@ class PeftLoraLoaderMixinTests:
pipe = pipe.to(torch_device, dtype=compute_dtype) pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
if storage_dtype is not None: if storage_dtype is not None:
denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype) denoiser.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment