Unverified Commit 81cf3b2f authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Tests] [LoRA] clean up the serialization stuff. (#9512)

* clean up the serialization stuff.

* better
parent 534848c3
......@@ -201,6 +201,32 @@ class PeftLoraLoaderMixinTests:
prepared_inputs["input_ids"] = inputs
return prepared_inputs
def _get_lora_state_dicts(self, modules_to_save):
state_dicts = {}
for module_name, module in modules_to_save.items():
if module is not None:
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
return state_dicts
def _get_modules_to_save(self, pipe, has_denoiser=False):
modules_to_save = {}
lora_loadable_modules = self.pipeline_class._lora_loadable_modules
if "text_encoder" in lora_loadable_modules and hasattr(pipe, "text_encoder"):
modules_to_save["text_encoder"] = pipe.text_encoder
if "text_encoder_2" in lora_loadable_modules and hasattr(pipe, "text_encoder_2"):
modules_to_save["text_encoder_2"] = pipe.text_encoder_2
if has_denoiser:
if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
modules_to_save["unet"] = pipe.unet
if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
modules_to_save["transformer"] = pipe.transformer
return modules_to_save
def test_simple_inference(self):
"""
Tests a simple inference and makes sure it works as expected
......@@ -420,45 +446,21 @@ class PeftLoraLoaderMixinTests:
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
safe_serialization=False,
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
safe_serialization=False,
)
modules_to_save = self._get_modules_to_save(pipe)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
if self.has_two_text_encoders:
if "text_encoder_2" not in self.pipeline_class._lora_loadable_modules:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
safe_serialization=False,
)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
......@@ -614,54 +616,20 @@ class PeftLoraLoaderMixinTests:
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
text_encoder_state_dict = (
get_peft_model_state_dict(pipe.text_encoder)
if "text_encoder" in self.pipeline_class._lora_loadable_modules
else None
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
denoiser_state_dict = get_peft_model_state_dict(denoiser)
saving_kwargs = {
"save_directory": tmpdirname,
"safe_serialization": False,
}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
saving_kwargs.update({"text_encoder_lora_layers": text_encoder_state_dict})
if self.unet_kwargs is not None:
saving_kwargs.update({"unet_lora_layers": denoiser_state_dict})
else:
saving_kwargs.update({"transformer_lora_layers": denoiser_state_dict})
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
saving_kwargs.update({"text_encoder_2_lora_layers": text_encoder_2_state_dict})
self.pipeline_class.save_lora_weights(**saving_kwargs)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment