Unverified commit 6fe05b9b, authored by Sayak Paul and committed by GitHub

[LoRA] make `set_adapters()` robust on silent failures. (#9618)

* make set_adapters() robust on silent failures.

* fixes to tests

* flaky decorator.

* fix

* flaky to sd3.

* remove warning.

* sort

* quality

* skip test_simple_inference_with_text_denoiser_multi_adapter_block_lora

* skip testing unsupported features.

* raise warning instead of error.
parent 2bc82d63
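
In short: when `adapter_weights` is passed as a dict, `set_adapters()` now validates the keys against the pipeline's LoRA-loadable components, warns about unknown ones, and drops them instead of letting them fail silently downstream. A minimal usage sketch (the checkpoint and LoRA repo below are illustrative placeholders, not taken from this PR):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("user/some-lora", adapter_name="adapter-1")  # hypothetical LoRA repo

# "unet" and "text_encoder" are LoRA-loadable components of this pipeline;
# "foo" is not. After this change it is dropped with a warning instead of
# being silently ignored somewhere downstream.
pipe.set_adapters("adapter-1", adapter_weights={"unet": 0.8, "text_encoder": 0.5, "foo": 1.0})
```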
@@ -661,8 +661,20 @@ class LoraBaseMixin:
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
-        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        if isinstance(adapter_weights, dict):
+            components_passed = set(adapter_weights.keys())
+            lora_components = set(self._lora_loadable_modules)
+
+            invalid_components = sorted(components_passed - lora_components)
+            if invalid_components:
+                logger.warning(
+                    f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
+                    f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
+                    "to the invalid components will be removed and ignored."
+                )
+                adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
         adapter_weights = copy.deepcopy(adapter_weights)

         # Expand weights into a list, one entry per adapter
@@ -697,12 +709,6 @@ class LoraBaseMixin:
         for adapter_name, weights in zip(adapter_names, adapter_weights):
             if isinstance(weights, dict):
                 component_adapter_weights = weights.pop(component, None)
-                if component_adapter_weights is not None and not hasattr(self, component):
-                    logger.warning(
-                        f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
-                    )
-
                 if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
                     logger.warning(
                         (
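
The validation above is plain set arithmetic. The following self-contained sketch (with hypothetical component names) shows exactly what the new branch computes:

```python
# Stand-ins for self._lora_loadable_modules and a user-supplied weights dict.
lora_loadable_modules = ["text_encoder", "text_encoder_2", "unet"]
adapter_weights = {"unet": 0.8, "foo": 0.0, "bar": 0.0}

components_passed = set(adapter_weights.keys())
lora_components = set(lora_loadable_modules)

# sorted() gives the warning message a deterministic order, which the new
# test added below relies on when matching the captured log output.
invalid_components = sorted(components_passed - lora_components)
if invalid_components:
    print(f"Ignoring components not in the pipeline: {invalid_components}")
    adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}

assert adapter_weights == {"unet": 0.8}  # "bar" and "foo" were dropped
```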
@@ -155,3 +155,7 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     @unittest.skip("Text encoder LoRA is not supported in CogVideoX.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
+
+    @unittest.skip("Not supported in CogVideoX.")
+    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+        pass
@@ -262,6 +262,10 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
             "LoRA should lead to different results.",
         )

+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_block_scale(self):
+        pass
+
+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+        pass
@@ -270,6 +274,10 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     def test_modify_padding_mode(self):
         pass

+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+        pass
+

 class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = FluxControlPipeline
@@ -783,6 +791,10 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         self.assertTrue(pipe.transformer.x_embedder.weight.data.shape[1] == in_features * 2)
         self.assertTrue(pipe.transformer.config.in_channels == in_features * 2)

+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_block_scale(self):
+        pass
+
+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
+        pass
@@ -791,6 +803,10 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     def test_modify_padding_mode(self):
         pass

+    @unittest.skip("Not supported in Flux.")
+    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+        pass
+

 @slow
 @nightly
@@ -136,3 +136,7 @@ class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     @unittest.skip("Text encoder LoRA is not supported in Mochi.")
     def test_simple_inference_with_text_lora_save_load(self):
         pass
+
+    @unittest.skip("Not supported in Mochi.")
+    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
+        pass
@@ -30,6 +30,7 @@ from diffusers import (
 from diffusers.utils import load_image
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.utils.testing_utils import (
+    is_flaky,
     nightly,
     numpy_cosine_similarity_distance,
     require_big_gpu_with_torch_cuda,
@@ -128,6 +129,10 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     def test_modify_padding_mode(self):
         pass

+    @is_flaky
+    def test_multiple_wrong_adapter_name_raises_error(self):
+        super().test_multiple_wrong_adapter_name_raises_error()
+

 @nightly
 @require_torch_gpu
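
`is_flaky` comes from `diffusers.utils.testing_utils` and reruns a test several times before letting a failure propagate, which is why it is applied to this numerically sensitive test here and below. The diff does not show its implementation; a simplified sketch of a bare retry decorator of this kind (not the actual diffusers code) might look like:

```python
import functools

def is_flaky(test_func, max_attempts=5):
    """Rerun a flaky test up to `max_attempts` times before failing for real."""

    @functools.wraps(test_func)
    def wrapper(*args, **kwargs):
        last_exc = None
        for _ in range(max_attempts):
            try:
                return test_func(*args, **kwargs)
            except Exception as exc:  # retry on any assertion or runtime error
                last_exc = exc
        raise last_exc  # out of retries: surface the last failure

    return wrapper
```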
@@ -37,6 +37,7 @@ from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    is_flaky,
     load_image,
     nightly,
     numpy_cosine_similarity_distance,
@@ -111,6 +112,10 @@ class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
         gc.collect()
         torch.cuda.empty_cache()

+    @is_flaky
+    def test_multiple_wrong_adapter_name_raises_error(self):
+        super().test_multiple_wrong_adapter_name_raises_error()
+

 @slow
 @nightly
@@ -1135,6 +1135,43 @@ class PeftLoraLoaderMixinTests:
         pipe.set_adapters("adapter-1")
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]

+    def test_multiple_wrong_adapter_name_raises_error(self):
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        if self.has_two_text_encoders or self.has_three_text_encoders:
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+
+        scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
+        logger = logging.get_logger("diffusers.loaders.lora_base")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            pipe.set_adapters("adapter-1", adapter_weights=scale_with_wrong_components)
+
+        wrong_components = sorted(set(scale_with_wrong_components.keys()))
+        msg = f"The following components in `adapter_weights` are not part of the pipeline: {wrong_components}. "
+        self.assertTrue(msg in str(cap_logger.out))
+
+        # test this works.
+        pipe.set_adapters("adapter-1")
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
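
A note on the new test above: `logger.setLevel(30)` sets the threshold to `logging.WARNING` (whose numeric value is 30), so `CaptureLogger` records the new warning. The same pattern works outside the test suite; a sketch, assuming `pipe` is a pipeline with a LoRA already loaded under the name "adapter-1":

```python
import logging

from diffusers.utils import logging as diffusers_logging
from diffusers.utils.testing_utils import CaptureLogger

logger = diffusers_logging.get_logger("diffusers.loaders.lora_base")
logger.setLevel(logging.WARNING)  # same threshold the test sets via setLevel(30)

with CaptureLogger(logger) as cap:
    # "foo" is not a LoRA-loadable component, so it is warned about and dropped.
    pipe.set_adapters("adapter-1", adapter_weights={"foo": 0.0})

assert "not part of the pipeline" in cap.out
```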