"...text-generation-inference.git" did not exist on "a6e4d63c86f4eeaae2ba1337a39f19d03bbd2277"
Unverified Commit 2daedc0a authored by Sayak Paul, committed by GitHub

[LoRA] make set_adapters() method more robust. (#9535)



* make set_adapters() method more robust.

* remove patch

* better and concise code.

* Update src/diffusers/loaders/lora_base.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 665c6b47
src/diffusers/loaders/lora_base.py
@@ -532,13 +532,19 @@ class LoraBaseMixin:
         )
 
         list_adapters = self.get_list_adapters()  # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
-        all_adapters = {
-            adapter for adapters in list_adapters.values() for adapter in adapters
-        }  # eg ["adapter1", "adapter2"]
+        # eg ["adapter1", "adapter2"]
+        all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
+        missing_adapters = set(adapter_names) - all_adapters
+        if len(missing_adapters) > 0:
+            raise ValueError(
+                f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
+            )
+
+        # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
         invert_list_adapters = {
             adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
             for adapter in all_adapters
-        }  # eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
+        }
 
         # Decompose weights into weights for denoiser and text encoders.
         _component_adapter_weights = {}
...
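In user-facing terms, the new guard turns a failure deep inside PEFT into an immediate, descriptive error at the set_adapters() call site. A minimal sketch of the behavior; the model ID, checkpoint paths, and adapter names are illustrative, not from the commit:

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Hypothetical LoRA checkpoints, loaded under explicit adapter names.
pipe.load_lora_weights("path/to/first_lora", adapter_name="adapter1")
pipe.load_lora_weights("path/to/second_lora", adapter_name="adapter2")

pipe.set_adapters(["adapter1", "adapter2"])  # fine: both names were loaded

# With this change, a typo fails fast with a clear message:
# ValueError: Adapter name(s) {'adapter-2'} not in the list of present adapters: {'adapter1', 'adapter2'}.
pipe.set_adapters(["adapter1", "adapter-2"])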
@@ -929,12 +929,24 @@ class PeftLoraLoaderMixinTests:
 
         pipe.set_adapters("adapter-1")
         output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertFalse(
+            np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
+            "Adapter outputs should be different.",
+        )
 
         pipe.set_adapters("adapter-2")
         output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertFalse(
+            np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
+            "Adapter outputs should be different.",
+        )
 
         pipe.set_adapters(["adapter-1", "adapter-2"])
         output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertFalse(
+            np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
+            "Adapter outputs should be different.",
+        )
 
         # Fuse and unfuse should lead to the same results
         self.assertFalse(
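Each added assertion follows the same recipe: generate with a fixed seed so the only varying factor is the active adapter, then require the LoRA output to differ from the base output by more than numerical noise. The same check in isolation, with hypothetical arrays standing in for pipeline outputs:

import numpy as np

output_no_lora = np.zeros((1, 8, 8, 3), dtype=np.float32)  # hypothetical base output
output_adapter_1 = output_no_lora + 0.01                   # hypothetical LoRA output

# atol/rtol mirror the tolerances used in the test; a difference smaller
# than them would be indistinguishable from floating-point noise.
assert not np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3)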
...
@@ -960,6 +972,38 @@ class PeftLoraLoaderMixinTests:
             "output with no lora and output with lora disabled should give same results",
         )
 
+    def test_wrong_adapter_name_raises_error(self):
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
+
+        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
+        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
+        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
+
+        if self.has_two_text_encoders or self.has_three_text_encoders:
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
+                self.assertTrue(
+                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
+                )
+
+        with self.assertRaises(ValueError) as err_context:
+            pipe.set_adapters("test")
+
+        self.assertTrue("not in the list of present adapters" in str(err_context.exception))
+
+        # test this works.
+        pipe.set_adapters("adapter-1")
+        _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
     def test_simple_inference_with_text_denoiser_block_scale(self):
         """
         Tests a simple inference with lora attached to text encoder and unet, attaches
...
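The guard itself is framework-independent: collect every loaded adapter name, subtract that set from the requested names, and raise if anything remains. A self-contained sketch of the same pattern outside diffusers (the function name is ours, not the library's):

def validate_adapter_names(adapter_names, list_adapters):
    """Raise if any requested adapter was never loaded.

    `list_adapters` maps component -> adapter names, e.g.
    {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}.
    """
    all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
    missing_adapters = set(adapter_names) - all_adapters
    if missing_adapters:
        raise ValueError(
            f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
        )


validate_adapter_names(["adapter1"], {"unet": ["adapter1", "adapter2"]})  # passes silently
# validate_adapter_names(["adpater1"], {"unet": ["adapter1"]})  # raises ValueError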