Unverified Commit 814d710e authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[tests] cache non lora pipeline outputs. (#12298)

* cache non lora pipeline outputs.

* up

* up

* up

* up

* Revert "up"

This reverts commit 772c32e43397f25919c29bbbe8ef9dc7d581cfb8.

* up

* Revert "up"

This reverts commit cca03df7fce55550ed28b59cadec12d1db188283.

* up

* up

* add .

* up

* up

* up

* up

* up

* up
parent cc5b31ff
...@@ -129,9 +129,6 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -129,9 +129,6 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
......
...@@ -122,9 +122,6 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -122,9 +122,6 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.transformer.add_adapter(denoiser_lora_config) pipe.transformer.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer") self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
...@@ -170,8 +167,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -170,8 +167,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
# Modify the config to have a layer which won't be present in the second LoRA we will load. # Modify the config to have a layer which won't be present in the second LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
...@@ -218,9 +214,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -218,9 +214,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = self.get_base_pipe_output()
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(output_no_lora.shape == self.output_shape)
# Modify the config to have a layer which won't be present in the first LoRA we will load. # Modify the config to have a layer which won't be present in the first LoRA we will load.
modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config) modified_denoiser_lora_config = copy.deepcopy(denoiser_lora_config)
...@@ -329,6 +323,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -329,6 +323,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
noise = floats_tensor((batch_size, num_channels) + sizes) noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator) input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
np.random.seed(0)
pipeline_inputs = { pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger", "prompt": "A painting of a squirrel eating a burger",
"control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")), "control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),
......
...@@ -169,7 +169,7 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests): ...@@ -169,7 +169,7 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipe = self.pipeline_class(**components).to(torch_device) pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape) self.assertTrue(output_no_lora.shape == self.output_shape)
# only supported for `denoiser` now # only supported for `denoiser` now
......
...@@ -126,13 +126,20 @@ class PeftLoraLoaderMixinTests: ...@@ -126,13 +126,20 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
def get_dummy_components(self, use_dora=False, lora_alpha=None): cached_non_lora_output = None
def get_base_pipe_output(self):
if self.cached_non_lora_output is None:
self.cached_non_lora_output = self._compute_baseline_output()
return self.cached_non_lora_output
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
if self.unet_kwargs and self.transformer_kwargs: if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders: if self.has_two_text_encoders and self.has_three_text_encoders:
raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.") raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")
scheduler_cls = self.scheduler_cls scheduler_cls = scheduler_cls if scheduler_cls is not None else self.scheduler_cls
rank = 4 rank = 4
lora_alpha = rank if lora_alpha is None else lora_alpha lora_alpha = rank if lora_alpha is None else lora_alpha
...@@ -238,15 +245,16 @@ class PeftLoraLoaderMixinTests: ...@@ -238,15 +245,16 @@ class PeftLoraLoaderMixinTests:
return noise, input_ids, pipeline_inputs return noise, input_ids, pipeline_inputs
# Copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb def _compute_baseline_output(self):
def get_dummy_tokens(self): components, _, _ = self.get_dummy_components(self.scheduler_cls)
max_seq_length = 77 pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)) pipe.set_progress_bar_config(disable=None)
prepared_inputs = {} # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
prepared_inputs["input_ids"] = inputs # explicitly.
return prepared_inputs _, _, inputs = self.get_dummy_inputs(with_generator=False)
return pipe(**inputs, generator=torch.manual_seed(0))[0]
def _get_lora_state_dicts(self, modules_to_save): def _get_lora_state_dicts(self, modules_to_save):
state_dicts = {} state_dicts = {}
...@@ -316,14 +324,8 @@ class PeftLoraLoaderMixinTests: ...@@ -316,14 +324,8 @@ class PeftLoraLoaderMixinTests:
""" """
Tests a simple inference and makes sure it works as expected Tests a simple inference and makes sure it works as expected
""" """
components, text_lora_config, _ = self.get_dummy_components() output_no_lora = self.get_base_pipe_output()
pipe = self.pipeline_class(**components) assert output_no_lora.shape == self.output_shape
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs()
output_no_lora = pipe(**inputs)[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
def test_simple_inference_with_text_lora(self): def test_simple_inference_with_text_lora(self):
""" """
...@@ -336,9 +338,7 @@ class PeftLoraLoaderMixinTests: ...@@ -336,9 +338,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
...@@ -414,9 +414,6 @@ class PeftLoraLoaderMixinTests: ...@@ -414,9 +414,6 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
...@@ -466,8 +463,7 @@ class PeftLoraLoaderMixinTests: ...@@ -466,8 +463,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
...@@ -503,8 +499,7 @@ class PeftLoraLoaderMixinTests: ...@@ -503,8 +499,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
...@@ -534,8 +529,7 @@ class PeftLoraLoaderMixinTests: ...@@ -534,8 +529,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
...@@ -566,9 +560,6 @@ class PeftLoraLoaderMixinTests: ...@@ -566,9 +560,6 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
...@@ -616,8 +607,7 @@ class PeftLoraLoaderMixinTests: ...@@ -616,8 +607,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
...@@ -666,9 +656,6 @@ class PeftLoraLoaderMixinTests: ...@@ -666,9 +656,6 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
...@@ -708,9 +695,6 @@ class PeftLoraLoaderMixinTests: ...@@ -708,9 +695,6 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
...@@ -747,9 +731,7 @@ class PeftLoraLoaderMixinTests: ...@@ -747,9 +731,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
...@@ -790,8 +772,7 @@ class PeftLoraLoaderMixinTests: ...@@ -790,8 +772,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
...@@ -825,8 +806,7 @@ class PeftLoraLoaderMixinTests: ...@@ -825,8 +806,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
...@@ -900,7 +880,7 @@ class PeftLoraLoaderMixinTests: ...@@ -900,7 +880,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
...@@ -1024,7 +1004,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1024,7 +1004,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
...@@ -1080,7 +1060,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1080,7 +1060,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
...@@ -1240,7 +1220,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1240,7 +1220,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
...@@ -1331,7 +1311,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1331,7 +1311,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
...@@ -1551,7 +1531,6 @@ class PeftLoraLoaderMixinTests: ...@@ -1551,7 +1531,6 @@ class PeftLoraLoaderMixinTests:
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked) self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
@require_peft_version_greater(peft_version="0.6.2")
def test_simple_inference_with_text_lora_denoiser_fused_multi( def test_simple_inference_with_text_lora_denoiser_fused_multi(
self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3 self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
): ):
...@@ -1565,9 +1544,6 @@ class PeftLoraLoaderMixinTests: ...@@ -1565,9 +1544,6 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder") self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
...@@ -1641,8 +1617,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1641,8 +1617,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1") pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
...@@ -1685,7 +1660,6 @@ class PeftLoraLoaderMixinTests: ...@@ -1685,7 +1660,6 @@ class PeftLoraLoaderMixinTests:
"LoRA should change the output", "LoRA should change the output",
) )
@require_peft_version_greater(peft_version="0.9.0")
def test_simple_inference_with_dora(self): def test_simple_inference_with_dora(self):
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True) components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
...@@ -1695,7 +1669,6 @@ class PeftLoraLoaderMixinTests: ...@@ -1695,7 +1669,6 @@ class PeftLoraLoaderMixinTests:
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_dora_lora.shape == self.output_shape) self.assertTrue(output_no_dora_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
...@@ -1783,7 +1756,6 @@ class PeftLoraLoaderMixinTests: ...@@ -1783,7 +1756,6 @@ class PeftLoraLoaderMixinTests:
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True) pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
...@@ -1820,7 +1792,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1820,7 +1792,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)} no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
logger = logging.get_logger("diffusers.loaders.peft") logger = logging.get_logger("diffusers.loaders.peft")
...@@ -1832,7 +1804,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1832,7 +1804,7 @@ class PeftLoraLoaderMixinTests:
denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer") denoiser = getattr(pipe, "unet") if self.unet_kwargs is not None else getattr(pipe, "transformer")
self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}")) self.assertTrue(cap_logger.out.startswith(f"No LoRA keys associated to {denoiser.__class__.__name__}"))
self.assertTrue(np.allclose(original_out, out_after_lora_attempt, atol=1e-5, rtol=1e-5)) self.assertTrue(np.allclose(output_no_lora, out_after_lora_attempt, atol=1e-5, rtol=1e-5))
# test only for text encoder # test only for text encoder
for lora_module in self.pipeline_class._lora_loadable_modules: for lora_module in self.pipeline_class._lora_loadable_modules:
...@@ -1864,9 +1836,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1864,9 +1836,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
lora_scale = 0.5 lora_scale = 0.5
...@@ -2212,9 +2182,6 @@ class PeftLoraLoaderMixinTests: ...@@ -2212,9 +2182,6 @@ class PeftLoraLoaderMixinTests:
pipe = self.pipeline_class(**components).to(torch_device) pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline( pipe, _ = self.add_adapters_to_pipeline(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
) )
...@@ -2260,7 +2227,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2260,7 +2227,7 @@ class PeftLoraLoaderMixinTests:
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] output_no_lora = self.get_base_pipe_output()
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config) pipe.text_encoder.add_adapter(text_lora_config)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment