Unverified Commit a4da2161, authored by Sayak Paul, committed by GitHub

[LoRA] improve LoRA fusion tests (#11274)



* improve lora fusion tests

* more improvements.

* remove comment

* update

* relax tolerance.

* num_fused_loras as a property
Co-authored-by: BenjaminBossan <benjamin.bossan@gmail.com>

* updates

* update

* fix

* fix
Co-authored-by: BenjaminBossan <benjamin.bossan@gmail.com>

* Update src/diffusers/loaders/lora_base.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

---------
Co-authored-by: BenjaminBossan <benjamin.bossan@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
parent 5939ace9
src/diffusers/loaders/lora_base.py
@@ -465,7 +465,7 @@ class LoraBaseMixin:
     """Utility class for handling LoRAs."""

     _lora_loadable_modules = []
-    num_fused_loras = 0
+    _merged_adapters = set()

     def load_lora_weights(self, **kwargs):
         raise NotImplementedError("`load_lora_weights()` is not implemented.")
@@ -592,6 +592,9 @@ class LoraBaseMixin:
         if len(components) == 0:
             raise ValueError("`components` cannot be an empty list.")

+        # Need to retrieve the names as `adapter_names` can be None. So we cannot directly use it
+        # in `self._merged_adapters = self._merged_adapters | merged_adapter_names`.
+        merged_adapter_names = set()
         for fuse_component in components:
             if fuse_component not in self._lora_loadable_modules:
                 raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
@@ -601,13 +604,19 @@ class LoraBaseMixin:
                 # check if diffusers model
                 if issubclass(model.__class__, ModelMixin):
                     model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
+                    for module in model.modules():
+                        if isinstance(module, BaseTunerLayer):
+                            merged_adapter_names.update(set(module.merged_adapters))
                 # handle transformers models.
                 if issubclass(model.__class__, PreTrainedModel):
                     fuse_text_encoder_lora(
                         model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
                     )
+                    for module in model.modules():
+                        if isinstance(module, BaseTunerLayer):
+                            merged_adapter_names.update(set(module.merged_adapters))

-        self.num_fused_loras += 1
+        self._merged_adapters = self._merged_adapters | merged_adapter_names

     def unfuse_lora(self, components: List[str] = [], **kwargs):
         r"""
@@ -661,9 +670,18 @@ class LoraBaseMixin:
             if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
                 for module in model.modules():
                     if isinstance(module, BaseTunerLayer):
+                        for adapter in set(module.merged_adapters):
+                            if adapter and adapter in self._merged_adapters:
+                                self._merged_adapters = self._merged_adapters - {adapter}
                         module.unmerge()

-        self.num_fused_loras -= 1
+    @property
+    def num_fused_loras(self):
+        return len(self._merged_adapters)
+
+    @property
+    def fused_loras(self):
+        return self._merged_adapters

     def set_adapters(
         self,
...
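For orientation, here is a minimal sketch (not part of the commit) of how the new bookkeeping behaves from the pipeline API. The checkpoint and LoRA repo ids below are hypothetical placeholders; the assertions follow the diff above, where fused adapter names are tracked in a set rather than an integer counter:

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    )
    # Hypothetical LoRA repos; any two adapters behave the same way.
    pipe.load_lora_weights("user/lora-one", adapter_name="adapter-1")
    pipe.load_lora_weights("user/lora-two", adapter_name="adapter-2")

    pipe.fuse_lora(components=["unet"], adapter_names=["adapter-1"])
    assert pipe.num_fused_loras == 1          # len(self._merged_adapters)
    assert pipe.fused_loras == {"adapter-1"}  # the merged adapter names

    pipe.fuse_lora(components=["unet"], adapter_names=["adapter-2"])
    assert pipe.num_fused_loras == 2          # set union, so re-fusing a name cannot double-count

    pipe.unfuse_lora(components=["unet"])
    assert pipe.num_fused_loras == 0          # unmerged names are removed from the set

Because the names are collected from each `BaseTunerLayer`'s `merged_adapters`, fusing the same adapter across several components counts once, which the old counter could not guarantee.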
@@ -124,6 +124,9 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     def test_simple_inference_with_text_denoiser_lora_unfused(self):
         super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

+    def test_lora_scale_kwargs_match_fusion(self):
+        super().test_lora_scale_kwargs_match_fusion(expected_atol=9e-3, expected_rtol=9e-3)
+
     @unittest.skip("Not supported in CogVideoX.")
     def test_simple_inference_with_text_denoiser_block_scale(self):
         pass
...
@@ -117,6 +117,40 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     def test_multiple_wrong_adapter_name_raises_error(self):
         super().test_multiple_wrong_adapter_name_raises_error()

+    def test_simple_inference_with_text_denoiser_lora_unfused(self):
+        if torch.cuda.is_available():
+            expected_atol = 9e-2
+            expected_rtol = 9e-2
+        else:
+            expected_atol = 1e-3
+            expected_rtol = 1e-3
+        super().test_simple_inference_with_text_denoiser_lora_unfused(
+            expected_atol=expected_atol, expected_rtol=expected_rtol
+        )
+
+    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
+        if torch.cuda.is_available():
+            expected_atol = 9e-2
+            expected_rtol = 9e-2
+        else:
+            expected_atol = 1e-3
+            expected_rtol = 1e-3
+        super().test_simple_inference_with_text_lora_denoiser_fused_multi(
+            expected_atol=expected_atol, expected_rtol=expected_rtol
+        )
+
+    def test_lora_scale_kwargs_match_fusion(self):
+        if torch.cuda.is_available():
+            expected_atol = 9e-2
+            expected_rtol = 9e-2
+        else:
+            expected_atol = 1e-3
+            expected_rtol = 1e-3
+        super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)
+

 @slow
 @nightly
...
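The three SDXL overrides above repeat the same CUDA-versus-CPU tolerance branch. A possible follow-up, sketched here with a hypothetical helper name (not part of this commit), would deduplicate it:

    import torch

    def fusion_tolerances(gpu_tol: float = 9e-2, cpu_tol: float = 1e-3):
        # Fused LoRA matmuls accumulate more numerical error on CUDA, hence the looser bound there.
        tol = gpu_tol if torch.cuda.is_available() else cpu_tol
        return tol, tol  # (expected_atol, expected_rtol)

    # e.g. inside an override:
    # expected_atol, expected_rtol = fusion_tolerances()
    # super().test_lora_scale_kwargs_match_fusion(expected_atol=expected_atol, expected_rtol=expected_rtol)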
@@ -80,6 +80,18 @@ def initialize_dummy_state_dict(state_dict):
 POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]


+def determine_attention_kwargs_name(pipeline_class):
+    call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()
+
+    # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
+    for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
+        if possible_attention_kwargs in call_signature_keys:
+            attention_kwargs_name = possible_attention_kwargs
+            break
+    assert attention_kwargs_name is not None
+    return attention_kwargs_name
+
+
 @require_peft_backend
 class PeftLoraLoaderMixinTests:
     pipeline_class = None
@@ -442,14 +454,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached on the text encoder + scale argument
         and makes sure it works as expected
         """
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-
-        # TODO(diffusers): Discuss a common naming convention across library for 1.0.0 release
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
@@ -740,12 +745,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference with lora attached on the text encoder + Unet + scale argument
         and makes sure it works as expected
         """
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
@@ -878,9 +878,11 @@ class PeftLoraLoaderMixinTests:
             pipe, denoiser = self.check_if_adapters_added_correctly(pipe, text_lora_config, denoiser_lora_config)

             pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
+            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
            output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

             pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
            output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

             # unloading should remove the LoRA layers
@@ -1608,26 +1610,21 @@ class PeftLoraLoaderMixinTests:
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
                 )
-                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

             denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
             denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+
+            # Attach a second adapter
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
+
+            denoiser.add_adapter(denoiser_lora_config, "adapter-2")
             self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
-            denoiser.add_adapter(denoiser_lora_config, "adapter-2")

             if self.has_two_text_encoders or self.has_three_text_encoders:
                 lora_loadable_components = self.pipeline_class._lora_loadable_modules
                 if "text_encoder_2" in lora_loadable_components:
                     pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                     self.assertTrue(
                         check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                     )
-                    pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")

             # set them to multi-adapter inference mode
             pipe.set_adapters(["adapter-1", "adapter-2"])
@@ -1637,6 +1634,7 @@ class PeftLoraLoaderMixinTests:
             outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

             pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
+            self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

             # Fusing should still keep the LoRA layers so output should remain the same
             outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1647,9 +1645,23 @@ class PeftLoraLoaderMixinTests:
             )

             pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
+
+            self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
+
+            if self.has_two_text_encoders or self.has_three_text_encoders:
+                if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
+                    )
+
             pipe.fuse_lora(
                 components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
             )
+            self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

             # Fusing should still keep the LoRA layers
             output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -1657,6 +1669,63 @@ class PeftLoraLoaderMixinTests:
                 np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
                 "Fused lora should not change the output",
             )
+            pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
+            self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+    def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
+
+        for lora_scale in [1.0, 0.8]:
+            for scheduler_cls in self.scheduler_classes:
+                components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+                pipe = self.pipeline_class(**components)
+                pipe = pipe.to(torch_device)
+                pipe.set_progress_bar_config(disable=None)
+
+                _, _, inputs = self.get_dummy_inputs(with_generator=False)
+                output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+                self.assertTrue(output_no_lora.shape == self.output_shape)
+
+                if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                    pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+                    self.assertTrue(
+                        check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
+                    )
+
+                denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+                denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+                self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+                if self.has_two_text_encoders or self.has_three_text_encoders:
+                    lora_loadable_components = self.pipeline_class._lora_loadable_modules
+                    if "text_encoder_2" in lora_loadable_components:
+                        pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                        self.assertTrue(
+                            check_if_lora_correctly_set(pipe.text_encoder_2),
+                            "Lora not correctly set in text encoder 2",
+                        )
+
+                pipe.set_adapters(["adapter-1"])
+                attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
+                outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
+
+                pipe.fuse_lora(
+                    components=self.pipeline_class._lora_loadable_modules,
+                    adapter_names=["adapter-1"],
+                    lora_scale=lora_scale,
+                )
+                self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
+
+                outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+                self.assertTrue(
+                    np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
+                    "Fused lora should not change the output",
+                )
+                self.assertFalse(
+                    np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
+                    "LoRA should change the output",
+                )
+
     @require_peft_version_greater(peft_version="0.9.0")
     def test_simple_inference_with_dora(self):
@@ -1838,12 +1907,7 @@ class PeftLoraLoaderMixinTests:
     def test_set_adapters_match_attention_kwargs(self):
         """Test to check if outputs after `set_adapters()` and attention kwargs match."""
-        call_signature_keys = inspect.signature(self.pipeline_class.__call__).parameters.keys()
-        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
-            if possible_attention_kwargs in call_signature_keys:
-                attention_kwargs_name = possible_attention_kwargs
-                break
-        assert attention_kwargs_name is not None
+        attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

         for scheduler_cls in self.scheduler_classes:
             components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
...
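As a self-contained illustration of the factored-out helper, the sketch below exercises the same lookup against two fake pipeline classes. The class names are hypothetical, and the loop is restructured to return directly instead of break-then-assert; otherwise it mirrors the version added in the diff:

    import inspect

    POSSIBLE_ATTENTION_KWARGS_NAMES = ["cross_attention_kwargs", "joint_attention_kwargs", "attention_kwargs"]

    def determine_attention_kwargs_name(pipeline_class):
        # Pick whichever attention-kwargs parameter the pipeline's __call__ accepts.
        call_signature_keys = inspect.signature(pipeline_class.__call__).parameters.keys()
        for possible_attention_kwargs in POSSIBLE_ATTENTION_KWARGS_NAMES:
            if possible_attention_kwargs in call_signature_keys:
                return possible_attention_kwargs
        raise AssertionError(f"No attention kwargs parameter found on {pipeline_class.__name__}.__call__")

    class FakeUNetPipeline:  # hypothetical stand-in for a UNet-based pipeline
        def __call__(self, prompt, cross_attention_kwargs=None): ...

    class FakeTransformerPipeline:  # hypothetical stand-in for a DiT-based pipeline
        def __call__(self, prompt, attention_kwargs=None): ...

    assert determine_attention_kwargs_name(FakeUNetPipeline) == "cross_attention_kwargs"
    assert determine_attention_kwargs_name(FakeTransformerPipeline) == "attention_kwargs"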