Unverified Commit 09e777a3 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[tests] Single scheduler in lora tests (#12315)

* single scheduler please.

* up

* up

* up
parent a72bc0c4
......@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = AuraFlowPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
......
......@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
CogVideoXPipeline,
CogVideoXTransformer3DModel,
......@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogVideoXPipeline
scheduler_cls = CogVideoXDPMScheduler
scheduler_kwargs = {"timestep_spacing": "trailing"}
scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
transformer_kwargs = {
"num_attention_heads": 4,
......
......@@ -50,7 +50,6 @@ class TokenizerWrapper:
class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = CogView4Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
......@@ -124,30 +123,29 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
for scheduler_cls in self.scheduler_classes:
components, _, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
pipe_from_pretrained.to(torch_device)
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
pipe_from_pretrained.to(torch_device)
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
self.assertTrue(
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
@parameterized.expand([("block_level", True), ("leaf_level", False)])
@require_torch_accelerator
......
......@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
@require_peft_backend
class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"patch_size": 1,
"in_channels": 4,
......@@ -282,9 +281,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = FluxControlPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"patch_size": 1,
"in_channels": 8,
......
......@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = HunyuanVideoPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
......
......@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = LTXPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
......
......@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = Lumina2Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
......@@ -141,33 +140,30 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
strict=False,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
......@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = MochiPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
......
......@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = QwenImagePipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
......
......@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = SanaPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {"shift": 7.0}
transformer_kwargs = {
"patch_size": 1,
"in_channels": 4,
......
......@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_kwargs = {}
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
transformer_kwargs = {
"sample_size": 32,
"patch_size": 1,
......
......@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
......
......@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
scheduler_kwargs = {}
transformer_kwargs = {
......@@ -165,9 +164,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
@require_peft_version_greater("0.13.2")
def test_lora_exclude_modules_wanvace(self):
scheduler_cls = self.scheduler_classes[0]
exclude_module_name = "vace_blocks.0.proj_out"
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
......
......@@ -26,8 +26,6 @@ from parameterized import parameterized
from diffusers import (
AutoencoderKL,
DDIMScheduler,
LCMScheduler,
UNet2DConditionModel,
)
from diffusers.utils import logging
......@@ -109,7 +107,6 @@ class PeftLoraLoaderMixinTests:
scheduler_cls = None
scheduler_kwargs = None
scheduler_classes = [DDIMScheduler, LCMScheduler]
has_two_text_encoders = False
has_three_text_encoders = False
......@@ -129,13 +126,13 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
def get_dummy_components(self, use_dora=False, lora_alpha=None):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders:
raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")
scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
scheduler_cls = self.scheduler_cls
rank = 4
lora_alpha = rank if lora_alpha is None else lora_alpha
......@@ -319,152 +316,143 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple inference and makes sure it works as expected
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs()
output_no_lora = pipe(**inputs)[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
_, _, inputs = self.get_dummy_inputs()
output_no_lora = pipe(**inputs)[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
def test_simple_inference_with_text_lora(self):
"""
Tests a simple inference with lora attached on the text encoder
and makes sure it works as expected
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
@require_peft_version_greater("0.13.1")
def test_low_cpu_mem_usage_with_injection(self):
"""Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.")
self.assertTrue(
"meta" in {p.device.type for p in pipe.text_encoder.parameters()},
"The LoRA params should be on 'meta' device.",
)
te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
"No param should be on 'meta' device.",
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
self.assertTrue(
"meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
)
denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder."
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
self.assertTrue(
"meta" in {p.device.type for p in pipe.text_encoder.parameters()},
"meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
"The LoRA params should be on 'meta' device.",
)
te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
"meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
"No param should be on 'meta' device.",
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
self.assertTrue(
"meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
)
denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
self.assertTrue(
"meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
"The LoRA params should be on 'meta' device.",
)
te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
"No param should be on 'meta' device.",
)
_, _, inputs = self.get_dummy_inputs()
output_lora = pipe(**inputs)[0]
self.assertTrue(output_lora.shape == self.output_shape)
_, _, inputs = self.get_dummy_inputs()
output_lora = pipe(**inputs)[0]
self.assertTrue(output_lora.shape == self.output_shape)
@require_peft_version_greater("0.13.1")
@require_transformers_version_greater("4.45.2")
def test_low_cpu_mem_usage_with_loading(self):
"""Tests if we can load LoRA state dict with low_cpu_mem_usage."""
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
# Now, check for `low_cpu_mem_usage.`
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
# Now, check for `low_cpu_mem_usage.`
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(
images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3
),
"Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
)
images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
)
def test_simple_inference_with_text_lora_and_scale(self):
"""
......@@ -472,147 +460,140 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
self.assertTrue(
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
self.assertTrue(
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
self.assertTrue(
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
"Lora + 0 scale should lead to same result as no LoRA",
)
self.assertTrue(
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
"Lora + 0 scale should lead to same result as no LoRA",
)
def test_simple_inference_with_text_lora_fused(self):
"""
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
)
ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
)
def test_simple_inference_with_text_lora_unloaded(self):
"""
Tests a simple inference with lora attached to text encoder, then unloads the lora weights
and makes sure it works as expected
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
self.assertFalse(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
)
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertFalse(
check_if_lora_correctly_set(pipe.text_encoder_2),
"Lora not correctly unloaded in text encoder 2",
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertFalse(
check_if_lora_correctly_set(pipe.text_encoder_2),
"Lora not correctly unloaded in text encoder 2",
)
ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
"Fused lora should change the output",
)
ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
"Fused lora should change the output",
)
def test_simple_inference_with_text_lora_save_load(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA.
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
def test_simple_inference_with_partial_text_lora(self):
"""
......@@ -620,142 +601,139 @@ class PeftLoraLoaderMixinTests:
with different ranks and some adapters removed
and makes sure it works as expected
"""
for scheduler_cls in self.scheduler_classes:
components, _, _ = self.get_dummy_components(scheduler_cls)
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config = LoraConfig(
r=4,
rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
lora_alpha=4,
target_modules=self.text_encoder_target_modules,
init_lora_weights=False,
use_dora=False,
)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, _, _ = self.get_dummy_components()
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config = LoraConfig(
r=4,
rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
lora_alpha=4,
target_modules=self.text_encoder_target_modules,
init_lora_weights=False,
use_dora=False,
)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
state_dict = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
# Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
# supports missing layers (PR#8324).
state_dict = {
f"text_encoder.{module_name}": param
for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
if "text_model.encoder.layers.4" not in module_name
}
state_dict = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
# Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
# supports missing layers (PR#8324).
state_dict = {
f"text_encoder.{module_name}": param
for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
if "text_model.encoder.layers.4" not in module_name
}
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
state_dict.update(
{
f"text_encoder_2.{module_name}": param
for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
if "text_model.encoder.layers.4" not in module_name
}
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
state_dict.update(
{
f"text_encoder_2.{module_name}": param
for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
if "text_model.encoder.layers.4" not in module_name
}
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
# Unload lora and load it back using the pipe.load_lora_weights machinery
pipe.unload_lora_weights()
pipe.load_lora_weights(state_dict)
# Unload lora and load it back using the pipe.load_lora_weights machinery
pipe.unload_lora_weights()
pipe.load_lora_weights(state_dict)
output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
"Removing adapters should change the output",
)
output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
"Removing adapters should change the output",
)
def test_simple_inference_save_pretrained_with_text_lora(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
pipe_from_pretrained.to(torch_device)
pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
pipe_from_pretrained.to(torch_device)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
"Lora not correctly set in text encoder",
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
"Lora not correctly set in text encoder",
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
"Lora not correctly set in text encoder 2",
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
"Lora not correctly set in text encoder 2",
)
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
self.assertTrue(
np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
def test_simple_inference_with_text_denoiser_lora_save_load(self):
"""
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
"""
......@@ -763,120 +741,112 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected
"""
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
self.assertTrue(
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
self.assertTrue(
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
self.assertTrue(
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
"Lora + 0 scale should lead to same result as no LoRA",
)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
"Lora + 0 scale should lead to same result as no LoRA",
pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
"The scaling parameter has not been correctly restored!",
)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
"The scaling parameter has not been correctly restored!",
)
def test_simple_inference_with_text_lora_denoiser_fused(self):
"""
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
# Fusing should still keep the LoRA layers
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
# Fusing should still keep the LoRA layers
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
)
output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
)
def test_simple_inference_with_text_denoiser_lora_unloaded(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
self.assertFalse(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
)
self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")
pipe.unload_lora_weights()
# unloading should remove the LoRA layers
self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertFalse(
check_if_lora_correctly_set(pipe.text_encoder_2),
"Lora not correctly unloaded in text encoder 2",
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertFalse(
check_if_lora_correctly_set(pipe.text_encoder_2),
"Lora not correctly unloaded in text encoder 2",
)
output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
"Fused lora should change the output",
)
output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
"Fused lora should change the output",
)
def test_simple_inference_with_text_denoiser_lora_unfused(
self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
......@@ -885,125 +855,120 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
# unloading should remove the LoRA layers
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
# unloading should remove the LoRA layers
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
)
# Fuse and unfuse should lead to the same results
self.assertTrue(
np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
"Fused lora should not change the output",
)
# Fuse and unfuse should lead to the same results
self.assertTrue(
np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
"Fused lora should not change the output",
)
def test_simple_inference_with_text_denoiser_multi_adapter(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.set_adapters("adapter-1")
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
"Adapter outputs should be different.",
)
pipe.set_adapters("adapter-1")
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
"Adapter outputs should be different.",
)
pipe.set_adapters("adapter-2")
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter outputs should be different.",
)
pipe.set_adapters("adapter-2")
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter outputs should be different.",
)
pipe.set_adapters(["adapter-1", "adapter-2"])
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter outputs should be different.",
)
pipe.set_adapters(["adapter-1", "adapter-2"])
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter outputs should be different.",
)
# Fuse and unfuse should lead to the same results
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
# Fuse and unfuse should lead to the same results
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 1 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 1 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 2 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 2 and mixed adapters should give different results",
)
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
def test_wrong_adapter_name_raises_error(self):
adapter_name = "adapter-1"
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -1024,8 +989,7 @@ class PeftLoraLoaderMixinTests:
def test_multiple_wrong_adapter_name_raises_error(self):
adapter_name = "adapter-1"
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -1054,131 +1018,127 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
one adapter and set different weights for different blocks (i.e. block lora)
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
pipe.set_adapters("adapter-1", weights_1)
output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
pipe.set_adapters("adapter-1", weights_1)
output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
weights_2 = {"unet": {"up": 5}}
pipe.set_adapters("adapter-1", weights_2)
output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
weights_2 = {"unet": {"up": 5}}
pipe.set_adapters("adapter-1", weights_2)
output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
"LoRA weights 1 and 2 should give different results",
)
self.assertFalse(
np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
"No adapter and LoRA weights 1 should give different results",
)
self.assertFalse(
np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
"No adapter and LoRA weights 2 should give different results",
)
self.assertFalse(
np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
"LoRA weights 1 and 2 should give different results",
)
self.assertFalse(
np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
"No adapter and LoRA weights 1 should give different results",
)
self.assertFalse(
np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
"No adapter and LoRA weights 2 should give different results",
)
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set different weights for different blocks (i.e. block lora)
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
scales_2 = {"unet": {"down": 5, "mid": 5}}
scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
scales_2 = {"unet": {"down": 5, "mid": 5}}
pipe.set_adapters("adapter-1", scales_1)
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters("adapter-1", scales_1)
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters("adapter-2", scales_2)
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters("adapter-2", scales_2)
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
# Fuse and unfuse should lead to the same results
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
# Fuse and unfuse should lead to the same results
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 1 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 1 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 2 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 2 and mixed adapters should give different results",
)
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
# a mismatching number of adapter_names and adapter_weights should raise an error
with self.assertRaises(ValueError):
pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
# a mismatching number of adapter_names and adapter_weights should raise an error
with self.assertRaises(ValueError):
pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
"""Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
......@@ -1274,170 +1234,164 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set/delete them
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.set_adapters("adapter-1")
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters("adapter-1")
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters("adapter-2")
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters("adapter-2")
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"])
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"])
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 1 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 1 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 2 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 2 and mixed adapters should give different results",
)
pipe.delete_adapters("adapter-1")
output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.delete_adapters("adapter-1")
output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
self.assertTrue(
np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
pipe.delete_adapters("adapter-2")
output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.delete_adapters("adapter-2")
output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
self.assertTrue(
np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
pipe.set_adapters(["adapter-1", "adapter-2"])
pipe.delete_adapters(["adapter-1", "adapter-2"])
pipe.set_adapters(["adapter-1", "adapter-2"])
pipe.delete_adapters(["adapter-1", "adapter-2"])
output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
self.assertTrue(
np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and set them
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.set_adapters("adapter-1")
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters("adapter-1")
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters("adapter-2")
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters("adapter-2")
output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"])
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"])
output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
# Fuse and unfuse should lead to the same results
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
# Fuse and unfuse should lead to the same results
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
"Adapter 1 and 2 should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 1 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 1 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 2 and mixed adapters should give different results",
)
self.assertFalse(
np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Adapter 2 and mixed adapters should give different results",
)
pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Weighted adapter and mixed adapter should give different results",
)
self.assertFalse(
np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
"Weighted adapter and mixed adapter should give different results",
)
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.disable_lora()
output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
self.assertTrue(
np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
"output with no lora and output with lora disabled should give same results",
)
@skip_mps
@pytest.mark.xfail(
......@@ -1445,164 +1399,157 @@ class PeftLoraLoaderMixinTests:
reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
strict=False,
)
def test_lora_fuse_nan(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
def test_lora_fuse_nan(self):
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
if self.unet_kwargs:
pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
"adapter-1"
].weight += float("inf")
else:
named_modules = [name for name, _ in pipe.transformer.named_modules()]
possible_tower_names = [
"transformer_blocks",
"blocks",
"joint_transformer_blocks",
"single_transformer_blocks",
]
filtered_tower_names = [
tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
]
if len(filtered_tower_names) == 0:
reason = (
f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
)
raise ValueError(reason)
for tower_name in filtered_tower_names:
transformer_tower = getattr(pipe.transformer, tower_name)
has_attn1 = any("attn1" in name for name in named_modules)
if has_attn1:
transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
else:
transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
if self.unet_kwargs:
pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
"inf"
)
else:
named_modules = [name for name, _ in pipe.transformer.named_modules()]
possible_tower_names = [
"transformer_blocks",
"blocks",
"joint_transformer_blocks",
"single_transformer_blocks",
]
filtered_tower_names = [
tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
]
if len(filtered_tower_names) == 0:
reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
raise ValueError(reason)
for tower_name in filtered_tower_names:
transformer_tower = getattr(pipe.transformer, tower_name)
has_attn1 = any("attn1" in name for name in named_modules)
if has_attn1:
transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
else:
transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)
# without we should not see an error, but every image will be black
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
out = pipe(**inputs)[0]
self.assertTrue(np.isnan(out).all())
def test_get_adapters(self):
"""
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
adapter_names = pipe.get_active_adapters()
self.assertListEqual(adapter_names, ["adapter-1"])
adapter_names = pipe.get_active_adapters()
self.assertListEqual(adapter_names, ["adapter-1"])
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
adapter_names = pipe.get_active_adapters()
self.assertListEqual(adapter_names, ["adapter-2"])
adapter_names = pipe.get_active_adapters()
self.assertListEqual(adapter_names, ["adapter-2"])
pipe.set_adapters(["adapter-1", "adapter-2"])
self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
pipe.set_adapters(["adapter-1", "adapter-2"])
self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
def test_get_list_adapters(self):
"""
Tests a simple usecase where we attach multiple adapters and check if the results
are the expected results
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
# 1.
dicts_to_be_checked = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
dicts_to_be_checked = {"text_encoder": ["adapter-1"]}
# 1.
dicts_to_be_checked = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
dicts_to_be_checked = {"text_encoder": ["adapter-1"]}
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
dicts_to_be_checked.update({"unet": ["adapter-1"]})
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
dicts_to_be_checked.update({"transformer": ["adapter-1"]})
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
dicts_to_be_checked.update({"unet": ["adapter-1"]})
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
dicts_to_be_checked.update({"transformer": ["adapter-1"]})
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
# 2.
dicts_to_be_checked = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
# 2.
dicts_to_be_checked = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
# 3.
pipe.set_adapters(["adapter-1", "adapter-2"])
# 3.
pipe.set_adapters(["adapter-1", "adapter-2"])
dicts_to_be_checked = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
dicts_to_be_checked = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
if self.unet_kwargs is not None:
dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
else:
dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
if self.unet_kwargs is not None:
dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
else:
dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
self.assertDictEqual(
pipe.get_list_adapters(),
dicts_to_be_checked,
)
self.assertDictEqual(
pipe.get_list_adapters(),
dicts_to_be_checked,
)
# 4.
dicts_to_be_checked = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
# 4.
dicts_to_be_checked = {}
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
@require_peft_version_greater(peft_version="0.6.2")
def test_simple_inference_with_text_lora_denoiser_fused_multi(
......@@ -1612,8 +1559,83 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
and makes sure it works as expected - with unet and multi-adapter case
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
# set them to multi-adapter inference mode
pipe.set_adapters(["adapter-1", "adapter-2"])
outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1"])
outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
# Fusing should still keep the LoRA layers so output should remain the same
outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
"Fused lora should not change the output",
)
pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
)
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"])
self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
# Fusing should still keep the LoRA layers
output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
"Fused lora should not change the output",
)
pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
for lora_scale in [1.0, 0.8]:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -1627,150 +1649,65 @@ class PeftLoraLoaderMixinTests:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
denoiser.add_adapter(denoiser_lora_config, "adapter-2")
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
check_if_lora_correctly_set(pipe.text_encoder_2),
"Lora not correctly set in text encoder 2",
)
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
# set them to multi-adapter inference mode
pipe.set_adapters(["adapter-1", "adapter-2"])
outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
pipe.set_adapters(["adapter-1"])
outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
pipe.fuse_lora(
components=self.pipeline_class._lora_loadable_modules,
adapter_names=["adapter-1"],
lora_scale=lora_scale,
)
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
# Fusing should still keep the LoRA layers so output should remain the same
outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
"Fused lora should not change the output",
)
pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
)
pipe.fuse_lora(
components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"]
)
self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
# Fusing should still keep the LoRA layers
output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
"Fused lora should not change the output",
self.assertFalse(
np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
"LoRA should change the output",
)
pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
for lora_scale in [1.0, 0.8]:
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2),
"Lora not correctly set in text encoder 2",
)
pipe.set_adapters(["adapter-1"])
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
pipe.fuse_lora(
components=self.pipeline_class._lora_loadable_modules,
adapter_names=["adapter-1"],
lora_scale=lora_scale,
)
self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
"Fused lora should not change the output",
)
self.assertFalse(
np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
"LoRA should change the output",
)
@require_peft_version_greater(peft_version="0.9.0")
def test_simple_inference_with_dora(self):
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, use_dora=True
)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_dora_lora.shape == self.output_shape)
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_dora_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(
np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
"DoRA lora should change the output",
)
self.assertFalse(
np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
"DoRA lora should change the output",
)
def test_missing_keys_warning(self):
scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`.
components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -1805,9 +1742,8 @@ class PeftLoraLoaderMixinTests:
self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))
def test_unexpected_keys_warning(self):
scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`.
components, _, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -1842,23 +1778,22 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
if self.has_two_text_encoders or self.has_three_text_encoders:
pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
if self.has_two_text_encoders or self.has_three_text_encoders:
pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
# Just makes sure it works..
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
# Just makes sure it works.
_ = pipe(**inputs, generator=torch.manual_seed(0))[0]
def test_modify_padding_mode(self):
def set_pad_mode(network, mode="circular"):
......@@ -1866,22 +1801,20 @@ class PeftLoraLoaderMixinTests:
if isinstance(module, torch.nn.Conv2d):
module.padding_mode = mode
for scheduler_cls in self.scheduler_classes:
components, _, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_pad_mode = "circular"
set_pad_mode(pipe.vae, _pad_mode)
set_pad_mode(pipe.unet, _pad_mode)
components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_pad_mode = "circular"
set_pad_mode(pipe.vae, _pad_mode)
set_pad_mode(pipe.unet, _pad_mode)
_, _, inputs = self.get_dummy_inputs()
_ = pipe(**inputs)[0]
_, _, inputs = self.get_dummy_inputs()
_ = pipe(**inputs)[0]
def test_logs_info_when_no_lora_keys_found(self):
scheduler_cls = self.scheduler_classes[0]
# Skip text encoder check for now as that is handled with `transformers`.
components, _, _ = self.get_dummy_components(scheduler_cls)
components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -1925,73 +1858,71 @@ class PeftLoraLoaderMixinTests:
def test_set_adapters_match_attention_kwargs(self):
"""Test to check if outputs after `set_adapters()` and attention kwargs match."""
attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
lora_scale = 0.5
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
self.assertFalse(
np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
pipe.set_adapters("default", lora_scale)
output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
self.assertTrue(
np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
"Lora + scale should match the output of `set_adapters()`.",
)
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
)
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
lora_scale = 0.5
attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
self.assertFalse(
np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
self.assertTrue(
not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
pipe.set_adapters("default", lora_scale)
output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results as attention_kwargs.",
)
self.assertTrue(
np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
"Lora + scale should match the output of `set_adapters()`.",
np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results as set_adapters().",
)
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
self.assertTrue(
not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
self.assertTrue(
np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results as attention_kwargs.",
)
self.assertTrue(
np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results as set_adapters().",
)
@require_peft_version_greater("0.13.2")
def test_lora_B_bias(self):
# Currently, this test is only relevant for Flux Control LoRA as we are not
# aware of any other LoRA checkpoint that has its `lora_B` biases trained.
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -2028,7 +1959,7 @@ class PeftLoraLoaderMixinTests:
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
def test_correct_lora_configs_with_different_ranks(self):
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -2114,7 +2045,7 @@ class PeftLoraLoaderMixinTests:
self.assertEqual(submodule.bias.dtype, dtype_to_check)
def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
......@@ -2181,7 +2112,7 @@ class PeftLoraLoaderMixinTests:
self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)
# 1. Test forward with add_adapter
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
......@@ -2211,7 +2142,7 @@ class PeftLoraLoaderMixinTests:
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device, dtype=compute_dtype)
pipe.set_progress_bar_config(disable=None)
......@@ -2231,10 +2162,7 @@ class PeftLoraLoaderMixinTests:
@parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, lora_alpha=lora_alpha
)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
pipe = self.pipeline_class(**components)
pipe, _ = self.add_adapters_to_pipeline(
......@@ -2280,10 +2208,7 @@ class PeftLoraLoaderMixinTests:
@parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, lora_alpha=lora_alpha
)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
......@@ -2311,8 +2236,7 @@ class PeftLoraLoaderMixinTests:
def test_lora_unload_add_adapter(self):
"""Tests if `unload_lora_weights()` -> `add_adapter()` works."""
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
......@@ -2330,51 +2254,48 @@ class PeftLoraLoaderMixinTests:
def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
denoiser.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
if self.has_two_text_encoders or self.has_three_text_encoders:
lora_loadable_components = self.pipeline_class._lora_loadable_modules
if "text_encoder_2" in lora_loadable_components:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
# First, delete adapter and compare.
pipe.delete_adapters(pipe.get_active_adapters()[0])
output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))
# First, delete adapter and compare.
pipe.delete_adapters(pipe.get_active_adapters()[0])
output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))
# Then load adapter and compare.
pipe.load_lora_weights(tmpdirname)
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
# Then load adapter and compare.
pipe.load_lora_weights(tmpdirname)
output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
......@@ -2382,7 +2303,7 @@ class PeftLoraLoaderMixinTests:
onload_device = torch_device
offload_device = torch.device("cpu")
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -2399,7 +2320,7 @@ class PeftLoraLoaderMixinTests:
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
components, _, _ = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
......@@ -2451,7 +2372,7 @@ class PeftLoraLoaderMixinTests:
@require_torch_accelerator
def test_lora_loading_model_cpu_offload(self):
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
components, _, denoiser_lora_config = self.get_dummy_components()
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
......@@ -2470,7 +2391,7 @@ class PeftLoraLoaderMixinTests:
save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
)
# reinitialize the pipeline to mimic the inference workflow.
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0])
components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.enable_model_cpu_offload(device=torch_device)
pipe.load_lora_weights(tmpdirname)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment