Unverified Commit 09e777a3 authored by Sayak Paul, committed by GitHub

[tests] Single scheduler in lora tests (#12315)

* single scheduler please.

* up

* up

* up
parent a72bc0c4
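
The diff below collapses the per-scheduler loop that every LoRA test used to run into a single scheduler path: the class-level `scheduler_classes` list is deleted everywhere, `get_dummy_components` loses its `scheduler_cls` argument, and each pipeline's test class pins exactly one scheduler through `scheduler_cls`/`scheduler_kwargs`. In outline (a sketch distilled from the diff, not a verbatim excerpt):

    # Before: each test body iterated over a list of scheduler classes.
    class PeftLoraLoaderMixinTests:
        scheduler_classes = [DDIMScheduler, LCMScheduler]

        def test_simple_inference(self):
            for scheduler_cls in self.scheduler_classes:
                components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
                ...

    # After: one scheduler per test class, so every test body de-indents by one level.
    class PeftLoraLoaderMixinTests:
        scheduler_cls = None
        scheduler_kwargs = None

        def test_simple_inference(self):
            components, text_lora_config, _ = self.get_dummy_components()
            ...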
@@ -43,7 +43,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class AuraFlowLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = AuraFlowPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
...
@@ -21,7 +21,6 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import (
     AutoencoderKLCogVideoX,
-    CogVideoXDDIMScheduler,
     CogVideoXDPMScheduler,
     CogVideoXPipeline,
     CogVideoXTransformer3DModel,
@@ -44,7 +43,6 @@ class CogVideoXLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = CogVideoXPipeline
     scheduler_cls = CogVideoXDPMScheduler
     scheduler_kwargs = {"timestep_spacing": "trailing"}
-    scheduler_classes = [CogVideoXDDIMScheduler, CogVideoXDPMScheduler]
     transformer_kwargs = {
         "num_attention_heads": 4,
...
@@ -50,7 +50,6 @@ class TokenizerWrapper:
 class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = CogView4Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -124,30 +123,29 @@ class CogView4LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, _, _ = self.get_dummy_components(scheduler_cls)
+        components, _, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)

             pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
             pipe_from_pretrained.to(torch_device)

         images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]

         self.assertTrue(
             np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
             "Loading from saved checkpoints should give same results.",
         )

     @parameterized.expand([("block_level", True), ("leaf_level", False)])
     @require_torch_accelerator
...
@@ -55,9 +55,8 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa
 @require_peft_backend
 class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = FluxPipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     transformer_kwargs = {
         "patch_size": 1,
         "in_channels": 4,
@@ -282,9 +281,8 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
 class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = FluxControlPipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler()
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     transformer_kwargs = {
         "patch_size": 1,
         "in_channels": 8,
...
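In both Flux test classes above, `scheduler_cls` previously held an instance, `FlowMatchEulerDiscreteScheduler()`, rather than the class itself; the stray parentheses were harmless while the removed `scheduler_classes` list drove the tests, but with `scheduler_cls` now the single source of truth it must stay a class. A minimal sketch of the assumed construction site (the body of `get_dummy_components` is not shown in this diff, so the exact line is an assumption):

    # Assumed pattern inside get_dummy_components(): build the scheduler from the
    # configured class and kwargs. This would fail if scheduler_cls were already an instance.
    scheduler = self.scheduler_cls(**self.scheduler_kwargs)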
@@ -51,7 +51,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class HunyuanVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = HunyuanVideoPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class LTXVideoLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = LTXPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
...
@@ -39,7 +39,6 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set  # noqa
 class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = Lumina2Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -141,33 +140,30 @@ class Lumina2LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         strict=False,
     )
     def test_lora_fuse_nan(self):
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
-            self.assertTrue(
-                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
-            )
+            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
         denoiser.add_adapter(denoiser_lora_config, "adapter-1")
         self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

         # corrupt one LoRA weight with `inf` values
         with torch.no_grad():
             pipe.transformer.layers[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

         # with `safe_fusing=True` we should see an Error
         with self.assertRaises(ValueError):
             pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

         # without we should not see an error, but every image will be black
         pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)
         out = pipe(**inputs)[0]

         self.assertTrue(np.isnan(out).all())
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class MochiLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = MochiPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
...
@@ -37,7 +37,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = QwenImagePipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
...
@@ -31,9 +31,8 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 @require_peft_backend
 class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = SanaPipeline
-    scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
-    scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
+    scheduler_cls = FlowMatchEulerDiscreteScheduler
+    scheduler_kwargs = {"shift": 7.0}
     transformer_kwargs = {
         "patch_size": 1,
         "in_channels": 4,
...
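The Sana class is the one spot where the pinned scheduler also carried a constructor argument; `shift=7.0` moves from the (incorrect) instance into `scheduler_kwargs`, so the class-plus-kwargs pair fully describes the scheduler. Assuming the construction pattern sketched earlier, the two spellings would be equivalent:

    # Hedged equivalence sketch, assuming the mixin builds scheduler_cls(**scheduler_kwargs).
    FlowMatchEulerDiscreteScheduler(**{"shift": 7.0})  # what the mixin is assumed to build now
    FlowMatchEulerDiscreteScheduler(shift=7.0)         # the instance the old attribute hard-coded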
@@ -55,7 +55,6 @@ class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = StableDiffusion3Pipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
     scheduler_kwargs = {}
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     transformer_kwargs = {
         "sample_size": 32,
         "patch_size": 1,
...
@@ -42,7 +42,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class WanLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = WanPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
...
@@ -50,7 +50,6 @@ from .utils import PeftLoraLoaderMixinTests  # noqa: E402
 class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = WanVACEPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler
-    scheduler_classes = [FlowMatchEulerDiscreteScheduler]
     scheduler_kwargs = {}
     transformer_kwargs = {
@@ -165,9 +164,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     @require_peft_version_greater("0.13.2")
     def test_lora_exclude_modules_wanvace(self):
-        scheduler_cls = self.scheduler_classes[0]
         exclude_module_name = "vace_blocks.0.proj_out"
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
...
@@ -26,8 +26,6 @@ from parameterized import parameterized
 from diffusers import (
     AutoencoderKL,
-    DDIMScheduler,
-    LCMScheduler,
     UNet2DConditionModel,
 )
 from diffusers.utils import logging
@@ -109,7 +107,6 @@ class PeftLoraLoaderMixinTests:
     scheduler_cls = None
     scheduler_kwargs = None
-    scheduler_classes = [DDIMScheduler, LCMScheduler]

     has_two_text_encoders = False
     has_three_text_encoders = False
@@ -129,13 +126,13 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

-    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
+    def get_dummy_components(self, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
         if self.has_two_text_encoders and self.has_three_text_encoders:
             raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")

-        scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
+        scheduler_cls = self.scheduler_cls
         rank = 4
         lora_alpha = rank if lora_alpha is None else lora_alpha
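
With the optional `scheduler_cls` argument gone, a test can no longer swap schedulers per call; each pipeline's test class pins exactly one scheduler via class attributes. A hedged sketch of the resulting convention (`MyPipeline` and the class name are hypothetical placeholders, not part of this diff):

    # Hypothetical subclass mirroring the per-pipeline test classes in this diff.
    class MyPipelineLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        pipeline_class = MyPipeline  # hypothetical
        scheduler_cls = FlowMatchEulerDiscreteScheduler
        scheduler_kwargs = {}

    # Test bodies then call the helper with no scheduler argument:
    # components, text_lora_config, denoiser_lora_config = self.get_dummy_components()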
@@ -319,152 +316,143 @@ class PeftLoraLoaderMixinTests:
         """
         Tests a simple inference and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)

         _, _, inputs = self.get_dummy_inputs()
         output_no_lora = pipe(**inputs)[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

     def test_simple_inference_with_text_lora(self):
         """
         Tests a simple inference with lora attached on the text encoder
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
         )

@require_peft_version_greater("0.13.1") @require_peft_version_greater("0.13.1")
def test_low_cpu_mem_usage_with_injection(self): def test_low_cpu_mem_usage_with_injection(self):
"""Tests if we can inject LoRA state dict with low_cpu_mem_usage.""" """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
for scheduler_cls in self.scheduler_classes: components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components)
pipe = self.pipeline_class(**components) pipe = pipe.to(torch_device)
pipe = pipe.to(torch_device) pipe.set_progress_bar_config(disable=None)
pipe.set_progress_bar_config(disable=None)
if "text_encoder" in self.pipeline_class._lora_loadable_modules: if "text_encoder" in self.pipeline_class._lora_loadable_modules:
inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True) inject_adapter_in_model(text_lora_config, pipe.text_encoder, low_cpu_mem_usage=True)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.")
self.assertTrue(
"meta" in {p.device.type for p in pipe.text_encoder.parameters()},
"The LoRA params should be on 'meta' device.",
)
te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder))
set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in pipe.text_encoder.parameters()},
"No param should be on 'meta' device.",
)
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True)
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
self.assertTrue(
"meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
)
denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
self.assertTrue( self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder." check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
) )
self.assertTrue( self.assertTrue(
"meta" in {p.device.type for p in pipe.text_encoder.parameters()}, "meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
"The LoRA params should be on 'meta' device.", "The LoRA params should be on 'meta' device.",
) )
te_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder)) te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
set_peft_model_state_dict(pipe.text_encoder, te_state_dict, low_cpu_mem_usage=True) set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
self.assertTrue( self.assertTrue(
"meta" not in {p.device.type for p in pipe.text_encoder.parameters()}, "meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
"No param should be on 'meta' device.", "No param should be on 'meta' device.",
) )
denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet _, _, inputs = self.get_dummy_inputs()
inject_adapter_in_model(denoiser_lora_config, denoiser, low_cpu_mem_usage=True) output_lora = pipe(**inputs)[0]
self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.") self.assertTrue(output_lora.shape == self.output_shape)
self.assertTrue(
"meta" in {p.device.type for p in denoiser.parameters()}, "The LoRA params should be on 'meta' device."
)
denoiser_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(denoiser))
set_peft_model_state_dict(denoiser, denoiser_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in denoiser.parameters()}, "No param should be on 'meta' device."
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
inject_adapter_in_model(text_lora_config, pipe.text_encoder_2, low_cpu_mem_usage=True)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
self.assertTrue(
"meta" in {p.device.type for p in pipe.text_encoder_2.parameters()},
"The LoRA params should be on 'meta' device.",
)
te2_state_dict = initialize_dummy_state_dict(get_peft_model_state_dict(pipe.text_encoder_2))
set_peft_model_state_dict(pipe.text_encoder_2, te2_state_dict, low_cpu_mem_usage=True)
self.assertTrue(
"meta" not in {p.device.type for p in pipe.text_encoder_2.parameters()},
"No param should be on 'meta' device.",
)
_, _, inputs = self.get_dummy_inputs()
output_lora = pipe(**inputs)[0]
self.assertTrue(output_lora.shape == self.output_shape)
@require_peft_version_greater("0.13.1") @require_peft_version_greater("0.13.1")
@require_transformers_version_greater("4.45.2") @require_transformers_version_greater("4.45.2")
def test_low_cpu_mem_usage_with_loading(self): def test_low_cpu_mem_usage_with_loading(self):
"""Tests if we can load LoRA state dict with low_cpu_mem_usage.""" """Tests if we can load LoRA state dict with low_cpu_mem_usage."""
components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
for scheduler_cls in self.scheduler_classes: output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls) self.assertTrue(output_no_lora.shape == self.output_shape)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config) pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True) modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save) lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
self.pipeline_class.save_lora_weights( self.pipeline_class.save_lora_weights(
save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
) )
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
pipe.unload_lora_weights() pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False) pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=False)
for module_name, module in modules_to_save.items(): for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue( self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3), np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.", "Loading from saved checkpoints should give same results.",
) )
# Now, check for `low_cpu_mem_usage.` # Now, check for `low_cpu_mem_usage.`
pipe.unload_lora_weights() pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True) pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"), low_cpu_mem_usage=True)
for module_name, module in modules_to_save.items(): for module_name, module in modules_to_save.items():
self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}") self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")
images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0] images_lora_from_pretrained_low_cpu = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
images_lora_from_pretrained_low_cpu, images_lora_from_pretrained, atol=1e-3, rtol=1e-3 "Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
), )
"Loading from saved checkpoints with `low_cpu_mem_usage` should give same results.",
)
     def test_simple_inference_with_text_lora_and_scale(self):
         """
@@ -472,147 +460,140 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
         )

         attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
         output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
         self.assertTrue(
             not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
             "Lora + scale should change the output",
         )

         attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
         output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
         self.assertTrue(
             np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
             "Lora + 0 scale should lead to same result as no LoRA",
         )

     def test_simple_inference_with_text_lora_fused(self):
         """
         Tests a simple inference with lora attached into text encoder + fuses the lora weights into base model
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
         pipe.fuse_lora()

         # Fusing should still keep the LoRA layers
         self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                 )

         ouput_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertFalse(
             np.allclose(ouput_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
         )

     def test_simple_inference_with_text_lora_unloaded(self):
         """
         Tests a simple inference with lora attached to text encoder, then unloads the lora weights
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
         pipe.unload_lora_weights()

         # unloading should remove the LoRA layers
-        self.assertFalse(
-            check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
-        )
+        self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                 self.assertFalse(
                     check_if_lora_correctly_set(pipe.text_encoder_2),
                     "Lora not correctly unloaded in text encoder 2",
                 )

         ouput_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             np.allclose(ouput_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
             "Fused lora should change the output",
         )

     def test_simple_inference_with_text_lora_save_load(self):
         """
         Tests a simple usecase where users could use saving utilities for LoRA.
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
             modules_to_save = self._get_modules_to_save(pipe)
             lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
             self.pipeline_class.save_lora_weights(
                 save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
             )

             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             pipe.unload_lora_weights()
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

         for module_name, module in modules_to_save.items():
             self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

         images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
             "Loading from saved checkpoints should give same results.",
         )

     def test_simple_inference_with_partial_text_lora(self):
         """
@@ -620,142 +601,139 @@ class PeftLoraLoaderMixinTests:
         with different ranks and some adapters removed
         and makes sure it works as expected
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, _, _ = self.get_dummy_components(scheduler_cls)
+        components, _, _ = self.get_dummy_components()
         # Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
         text_lora_config = LoraConfig(
             r=4,
             rank_pattern={self.text_encoder_target_modules[i]: i + 1 for i in range(3)},
             lora_alpha=4,
             target_modules=self.text_encoder_target_modules,
             init_lora_weights=False,
             use_dora=False,
         )

         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)

         state_dict = {}
         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             # Gather the state dict for the PEFT model, excluding `layers.4`, to ensure `load_lora_into_text_encoder`
             # supports missing layers (PR#8324).
             state_dict = {
                 f"text_encoder.{module_name}": param
                 for module_name, param in get_peft_model_state_dict(pipe.text_encoder).items()
                 if "text_model.encoder.layers.4" not in module_name
             }

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                 state_dict.update(
                     {
                         f"text_encoder_2.{module_name}": param
                         for module_name, param in get_peft_model_state_dict(pipe.text_encoder_2).items()
                         if "text_model.encoder.layers.4" not in module_name
                     }
                 )

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
         )

         # Unload lora and load it back using the pipe.load_lora_weights machinery
         pipe.unload_lora_weights()
         pipe.load_lora_weights(state_dict)

         output_partial_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             not np.allclose(output_partial_lora, output_lora, atol=1e-3, rtol=1e-3),
             "Removing adapters should change the output",
         )

     def test_simple_inference_save_pretrained_with_text_lora(self):
         """
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, _ = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
             pipe.save_pretrained(tmpdirname)

             pipe_from_pretrained = self.pipeline_class.from_pretrained(tmpdirname)
             pipe_from_pretrained.to(torch_device)

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             self.assertTrue(
                 check_if_lora_correctly_set(pipe_from_pretrained.text_encoder),
                 "Lora not correctly set in text encoder",
             )

         if self.has_two_text_encoders or self.has_three_text_encoders:
             if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                 self.assertTrue(
                     check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
                     "Lora not correctly set in text encoder 2",
                 )

         images_lora_save_pretrained = pipe_from_pretrained(**inputs, generator=torch.manual_seed(0))[0]

         self.assertTrue(
             np.allclose(images_lora, images_lora_save_pretrained, atol=1e-3, rtol=1e-3),
             "Loading from saved checkpoints should give same results.",
         )

     def test_simple_inference_with_text_denoiser_lora_save_load(self):
         """
         Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
         """
-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         images_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

         with tempfile.TemporaryDirectory() as tmpdirname:
             modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
             lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
             self.pipeline_class.save_lora_weights(
                 save_directory=tmpdirname, safe_serialization=False, **lora_state_dicts
             )

             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             pipe.unload_lora_weights()
             pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))

         for module_name, module in modules_to_save.items():
             self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

         images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
             "Loading from saved checkpoints should give same results.",
         )

     def test_simple_inference_with_text_denoiser_lora_and_scale(self):
         """
@@ -763,120 +741,112 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

-        for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
         pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

         output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

         output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
         self.assertTrue(
             not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
         )

         attention_kwargs = {attention_kwargs_name: {"scale": 0.5}}
         output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

         self.assertTrue(
             not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
             "Lora + scale should change the output",
         )

         attention_kwargs = {attention_kwargs_name: {"scale": 0.0}}
         output_lora_0_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

         self.assertTrue(
             np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
             "Lora + 0 scale should lead to same result as no LoRA",
         )

         if "text_encoder" in self.pipeline_class._lora_loadable_modules:
             self.assertTrue(
                 pipe.text_encoder.text_model.encoder.layers[0].self_attn.q_proj.scaling["default"] == 1.0,
                 "The scaling parameter has not been correctly restored!",
             )

    def test_simple_inference_with_text_lora_denoiser_fused(self):
        """
        Tests a simple inference with LoRA attached to the text encoder and denoiser, fuses the
        LoRA weights into the base model, and makes sure the fused output differs from the base output.
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)

        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)

        # Fusing should still keep the LoRA layers
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        output_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
        )
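
# The fusion path being tested is the usual deployment pattern: fold the low-rank update into
# the base weights so inference pays no adapter overhead. A sketch continuing the hypothetical
# `pipe` from the first example above:

pipe.fuse_lora()  # folds W + (alpha/r) * B @ A into W for every LoRA layer
image = pipe("an astronaut riding a horse").images[0]
pipe.unfuse_lora()  # subtracts the update again; the adapter stays attached
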
    def test_simple_inference_with_text_denoiser_lora_unloaded(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
        and makes sure it works as expected
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)

        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
        pipe.unload_lora_weights()

        # unloading should remove the LoRA layers
        self.assertFalse(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder")
        self.assertFalse(check_if_lora_correctly_set(denoiser), "Lora not correctly unloaded in denoiser")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertFalse(
                    check_if_lora_correctly_set(pipe.text_encoder_2),
                    "Lora not correctly unloaded in text encoder 2",
                )

        output_unloaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_unloaded, output_no_lora, atol=1e-3, rtol=1e-3),
            "Unloading LoRA should restore the original (no-LoRA) output",
        )
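
# `unload_lora_weights` goes one step further than disabling: it strips the adapter modules
# entirely. Sketch, continuing the same hypothetical `pipe` (and the `torch` import) from above:

generator = torch.manual_seed(0)
with_lora = pipe("a corgi", generator=generator).images[0]

pipe.unload_lora_weights()  # removes every adapter; base weights are untouched
generator = torch.manual_seed(0)
without_lora = pipe("a corgi", generator=generator).images[0]  # matches the base model
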
    def test_simple_inference_with_text_denoiser_lora_unfused(
        self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3
@@ -885,125 +855,120 @@ class PeftLoraLoaderMixinTests:
        """
        Tests a simple inference with LoRA attached to the text encoder and denoiser, fuses and then
        unfuses the LoRA weights, and makes sure the fused and unfused outputs match.
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules)
        self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
        output_fused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
        output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        # unfusing should keep the LoRA layers attached
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

        self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                )

        # Fuse and unfuse should lead to the same results
        self.assertTrue(
            np.allclose(output_fused_lora, output_unfused_lora, atol=expected_atol, rtol=expected_rtol),
            "Fused lora should not change the output",
        )
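
# `num_fused_loras` is the bookkeeping this test leans on; it tracks fuse/unfuse round trips.
# Sketch with the hypothetical `pipe` from the earlier examples:

pipe.fuse_lora(components=["unet", "text_encoder"])
assert pipe.num_fused_loras == 1

pipe.unfuse_lora()
assert pipe.num_fused_loras == 0  # weights restored, adapter still attached
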
    def test_simple_inference_with_text_denoiser_multi_adapter(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets them
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        pipe.set_adapters("adapter-1")
        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_no_lora, output_adapter_1, atol=1e-3, rtol=1e-3),
            "Adapter outputs should be different.",
        )

        pipe.set_adapters("adapter-2")
        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_no_lora, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter outputs should be different.",
        )

        pipe.set_adapters(["adapter-1", "adapter-2"])
        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_no_lora, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter outputs should be different.",
        )

        # Each adapter configuration should produce a distinct output
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter 1 and 2 should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 1 and mixed adapters should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 2 and mixed adapters should give different results",
        )

        pipe.disable_lora()
        output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
            "output with no lora and output with lora disabled should give same results",
        )
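
# Outside the test harness, the same multi-adapter flow goes through `load_lora_weights` with
# explicit adapter names. Sketch with the hypothetical `pipe`; repo ids are placeholders:

pipe.load_lora_weights("some-user/style-lora", adapter_name="style")
pipe.load_lora_weights("some-user/subject-lora", adapter_name="subject")

pipe.set_adapters(["style", "subject"])  # both adapters active at once
image = pipe("a corgi in watercolor").images[0]

pipe.disable_lora()  # back to the base model without detaching anything
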
    def test_wrong_adapter_name_raises_error(self):
        adapter_name = "adapter-1"

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
@@ -1024,8 +989,7 @@ class PeftLoraLoaderMixinTests:
    def test_multiple_wrong_adapter_name_raises_error(self):
        adapter_name = "adapter-1"

        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
@@ -1054,131 +1018,127 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached to text encoder and unet, attaches
        one adapter and sets different weights for different blocks (i.e. block lora)
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config)
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        weights_1 = {"text_encoder": 2, "unet": {"down": 5}}
        pipe.set_adapters("adapter-1", weights_1)
        output_weights_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        weights_2 = {"unet": {"up": 5}}
        pipe.set_adapters("adapter-1", weights_2)
        output_weights_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(
            np.allclose(output_weights_1, output_weights_2, atol=1e-3, rtol=1e-3),
            "LoRA weights 1 and 2 should give different results",
        )
        self.assertFalse(
            np.allclose(output_no_lora, output_weights_1, atol=1e-3, rtol=1e-3),
            "No adapter and LoRA weights 1 should give different results",
        )
        self.assertFalse(
            np.allclose(output_no_lora, output_weights_2, atol=1e-3, rtol=1e-3),
            "No adapter and LoRA weights 2 should give different results",
        )

        pipe.disable_lora()
        output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
            "output with no lora and output with lora disabled should give same results",
        )
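
# The scale dictionaries above mirror the user-facing block-lora feature: a scale can target a
# whole component or individual UNet block groups. Sketch with the hypothetical "style" adapter
# loaded in the earlier example:

scales = {"text_encoder": 0.5, "unet": {"down": 0.9, "mid": 1.0, "up": 1.2}}
pipe.set_adapters("style", scales)
image = pipe("a low-poly fox").images[0]
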
    def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets different weights for different blocks (i.e. block lora)
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        scales_1 = {"text_encoder": 2, "unet": {"down": 5}}
        scales_2 = {"unet": {"down": 5, "mid": 5}}

        pipe.set_adapters("adapter-1", scales_1)
        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters("adapter-2", scales_2)
        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1, scales_2])
        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

        # Each adapter configuration should produce a distinct output
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter 1 and 2 should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 1 and mixed adapters should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 2 and mixed adapters should give different results",
        )

        pipe.disable_lora()
        output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
            "output with no lora and output with lora disabled should give same results",
        )

        # a mismatching number of adapter_names and adapter_weights should raise an error
        with self.assertRaises(ValueError):
            pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
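
# The final assertion above documents an input contract: `adapter_names` and the weights list
# must line up one-to-one. Sketch, same hypothetical `pipe` and adapter names:

try:
    pipe.set_adapters(["style", "subject"], [{"unet": {"down": 0.9}}])  # 2 names, 1 weight entry
except ValueError as err:
    print(f"rejected as expected: {err}")
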
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        """Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
@@ -1274,170 +1234,164 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets/deletes them
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        pipe.set_adapters("adapter-1")
        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters("adapter-2")
        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters(["adapter-1", "adapter-2"])
        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter 1 and 2 should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 1 and mixed adapters should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 2 and mixed adapters should give different results",
        )

        pipe.delete_adapters("adapter-1")
        output_deleted_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_deleted_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Deleting adapter 1 should leave only adapter 2 active",
        )

        pipe.delete_adapters("adapter-2")
        output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
            "output with no lora and output with all adapters deleted should give same results",
        )

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        pipe.set_adapters(["adapter-1", "adapter-2"])
        pipe.delete_adapters(["adapter-1", "adapter-2"])

        output_deleted_adapters = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_deleted_adapters, atol=1e-3, rtol=1e-3),
            "output with no lora and output with all adapters deleted should give same results",
        )
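
# `delete_adapters` removes a named adapter while leaving the others attached and active, which
# is exactly what the equality checks above pin down. Sketch with the hypothetical adapters:

pipe.set_adapters(["style", "subject"])
pipe.delete_adapters("style")       # only "subject" remains
pipe.delete_adapters(["subject"])   # now equivalent to the base model
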
    def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
        """
        Tests a simple inference with lora attached to text encoder and unet, attaches
        multiple adapters and sets them with different weights
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )

        pipe.set_adapters("adapter-1")
        output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters("adapter-2")
        output_adapter_2 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters(["adapter-1", "adapter-2"])
        output_adapter_mixed = pipe(**inputs, generator=torch.manual_seed(0))[0]

        # Each adapter configuration should produce a distinct output
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_2, atol=1e-3, rtol=1e-3),
            "Adapter 1 and 2 should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_1, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 1 and mixed adapters should give different results",
        )
        self.assertFalse(
            np.allclose(output_adapter_2, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Adapter 2 and mixed adapters should give different results",
        )

        pipe.set_adapters(["adapter-1", "adapter-2"], [0.5, 0.6])
        output_adapter_mixed_weighted = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(
            np.allclose(output_adapter_mixed_weighted, output_adapter_mixed, atol=1e-3, rtol=1e-3),
            "Weighted adapter and mixed adapter should give different results",
        )

        pipe.disable_lora()
        output_disabled = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_no_lora, output_disabled, atol=1e-3, rtol=1e-3),
            "output with no lora and output with lora disabled should give same results",
        )
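
# Weighted mixing is the per-adapter analogue of the single scale kwarg: each active adapter
# gets its own multiplier. Sketch, same hypothetical `pipe`:

pipe.set_adapters(["style", "subject"], adapter_weights=[0.5, 0.6])
image = pipe("a corgi in watercolor").images[0]
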
    @skip_mps
    @pytest.mark.xfail(
@@ -1445,164 +1399,157 @@ class PeftLoraLoaderMixinTests:
        reason="Test currently fails on CPU and PyTorch 2.5.1 but not on PyTorch 2.4.1.",
        strict=False,
    )
    def test_lora_fuse_nan(self):
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        # corrupt one LoRA weight with `inf` values
        with torch.no_grad():
            if self.unet_kwargs:
                pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
                    "inf"
                )
            else:
                named_modules = [name for name, _ in pipe.transformer.named_modules()]
                possible_tower_names = [
                    "transformer_blocks",
                    "blocks",
                    "joint_transformer_blocks",
                    "single_transformer_blocks",
                ]
                filtered_tower_names = [
                    tower_name for tower_name in possible_tower_names if hasattr(pipe.transformer, tower_name)
                ]
                if len(filtered_tower_names) == 0:
                    reason = f"`pipe.transformer` didn't have any of the following attributes: {possible_tower_names}."
                    raise ValueError(reason)
                for tower_name in filtered_tower_names:
                    transformer_tower = getattr(pipe.transformer, tower_name)
                    has_attn1 = any("attn1" in name for name in named_modules)
                    if has_attn1:
                        transformer_tower[0].attn1.to_q.lora_A["adapter-1"].weight += float("inf")
                    else:
                        transformer_tower[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")

        # with `safe_fusing=True` we should see an Error
        with self.assertRaises(ValueError):
            pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=True)

        # without we should not see an error, but every image will be black
        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, safe_fusing=False)

        out = pipe(**inputs)[0]

        self.assertTrue(np.isnan(out).all())
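
# `safe_fusing=True` is the guard the first half of this test exercises: it validates the fused
# weights for NaN/inf and raises instead of silently producing black images. Sketch, same
# hypothetical `pipe`:

pipe.fuse_lora(safe_fusing=True)  # raises ValueError if fusion would corrupt the weights
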
    def test_get_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")

        adapter_names = pipe.get_active_adapters()
        self.assertListEqual(adapter_names, ["adapter-1"])

        pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")

        adapter_names = pipe.get_active_adapters()
        self.assertListEqual(adapter_names, ["adapter-2"])

        pipe.set_adapters(["adapter-1", "adapter-2"])
        self.assertListEqual(pipe.get_active_adapters(), ["adapter-1", "adapter-2"])
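
# `get_active_adapters` reports only what the most recent attach or `set_adapters` call left
# active, which is why the test sees ["adapter-2"] after the second attach. Sketch:

print(pipe.get_active_adapters())        # e.g. ["style"]
pipe.set_adapters(["style", "subject"])
print(pipe.get_active_adapters())        # ["style", "subject"]
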
    def test_get_list_adapters(self):
        """
        Tests a simple usecase where we attach multiple adapters and check if the results
        are the expected results
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        # 1.
        dicts_to_be_checked = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            dicts_to_be_checked = {"text_encoder": ["adapter-1"]}

        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
            dicts_to_be_checked.update({"unet": ["adapter-1"]})
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
            dicts_to_be_checked.update({"transformer": ["adapter-1"]})

        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

        # 2.
        dicts_to_be_checked = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
            dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
            dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)

        # 3.
        pipe.set_adapters(["adapter-1", "adapter-2"])

        dicts_to_be_checked = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

        if self.unet_kwargs is not None:
            dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
        else:
            dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})

        self.assertDictEqual(
            pipe.get_list_adapters(),
            dicts_to_be_checked,
        )

        # 4.
        dicts_to_be_checked = {}
        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}

        if self.unet_kwargs is not None:
            pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
            dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
        else:
            pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
            dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})

        self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
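
# `get_list_adapters` differs from `get_active_adapters`: it enumerates everything attached,
# grouped per component, whether active or not. Sketch:

print(pipe.get_list_adapters())
# e.g. {"unet": ["style", "subject"], "text_encoder": ["style"]}
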
    @require_peft_version_greater(peft_version="0.6.2")
    def test_simple_inference_with_text_lora_denoiser_fused_multi(
@@ -1612,8 +1559,83 @@ class PeftLoraLoaderMixinTests:
        Tests a simple inference with LoRA attached to the text encoder, fuses the LoRA weights into
        the base model, and makes sure it works as expected - with the denoiser and the multi-adapter case
        """
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(output_no_lora.shape == self.output_shape)

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
            pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")
        denoiser.add_adapter(denoiser_lora_config, "adapter-2")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
                )
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")

        # set them to multi-adapter inference mode
        pipe.set_adapters(["adapter-1", "adapter-2"])
        outputs_all_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.set_adapters(["adapter-1"])
        outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-1"])
        self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

        # Fusing should still keep the LoRA layers so output should remain the same
        outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]

        self.assertTrue(
            np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
            "Fused lora should not change the output",
        )

        pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

        if "text_encoder" in self.pipeline_class._lora_loadable_modules:
            self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")

        self.assertTrue(check_if_lora_correctly_set(denoiser), "Unfuse should still keep LoRA layers")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
                )

        pipe.fuse_lora(components=self.pipeline_class._lora_loadable_modules, adapter_names=["adapter-2", "adapter-1"])
        self.assertTrue(pipe.num_fused_loras == 2, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

        # Fusing should still keep the LoRA layers
        output_all_lora_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(output_all_lora_fused, outputs_all_lora, atol=expected_atol, rtol=expected_rtol),
            "Fused lora should not change the output",
        )
        pipe.unfuse_lora(components=self.pipeline_class._lora_loadable_modules)
        self.assertTrue(pipe.num_fused_loras == 0, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")
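
# Partial fusion via `adapter_names` is the capability under test: only the named adapters are
# folded into the base weights, and `num_fused_loras` tracks how many. Sketch, same
# hypothetical `pipe` and adapters:

pipe.fuse_lora(components=["unet", "text_encoder"], adapter_names=["style"])
assert pipe.num_fused_loras == 1
pipe.unfuse_lora()  # reverts the fusion; both adapters remain attached
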
def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expected_rtol: float = 1e-3):
    attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)

    for lora_scale in [1.0, 0.8]:
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
...@@ -1627,150 +1649,65 @@ class PeftLoraLoaderMixinTests:
            self.assertTrue(
                check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder"
            )

        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
        denoiser.add_adapter(denoiser_lora_config, "adapter-1")
        self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

        if self.has_two_text_encoders or self.has_three_text_encoders:
            lora_loadable_components = self.pipeline_class._lora_loadable_modules
            if "text_encoder_2" in lora_loadable_components:
                pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
                self.assertTrue(
                    check_if_lora_correctly_set(pipe.text_encoder_2),
                    "Lora not correctly set in text encoder 2",
                )

        pipe.set_adapters(["adapter-1"])
        attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
        outputs_lora_1 = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]

        pipe.fuse_lora(
            components=self.pipeline_class._lora_loadable_modules,
            adapter_names=["adapter-1"],
            lora_scale=lora_scale,
        )
        self.assertTrue(pipe.num_fused_loras == 1, f"{pipe.num_fused_loras=}, {pipe.fused_loras=}")

        outputs_lora_1_fused = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(
            np.allclose(outputs_lora_1, outputs_lora_1_fused, atol=expected_atol, rtol=expected_rtol),
            "Fused lora should not change the output",
        )
        self.assertFalse(
            np.allclose(output_no_lora, outputs_lora_1, atol=expected_atol, rtol=expected_rtol),
            "LoRA should change the output",
        )
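# Illustrative sketch of the equivalence this test checks: passing a LoRA scale per
# call versus baking the same scale in with `fuse_lora`. Not part of the test suite;
# the repo names are assumptions, and the kwargs name (`cross_attention_kwargs`)
# applies to UNet-based pipelines such as Stable Diffusion.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")

# Option 1: scale applied at call time through the attention kwargs.
image_a = pipe("a cat", generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.8}).images[0]

# Option 2: the same scale fused into the weights; outputs should match option 1.
pipe.fuse_lora(lora_scale=0.8)
image_b = pipe("a cat", generator=torch.manual_seed(0)).images[0]
pipe.unfuse_lora()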
@require_peft_version_greater(peft_version="0.9.0")
def test_simple_inference_with_dora(self):
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components(use_dora=True)
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
    self.assertTrue(output_no_dora_lora.shape == self.output_shape)

    pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

    output_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

    self.assertFalse(
        np.allclose(output_dora_lora, output_no_dora_lora, atol=1e-3, rtol=1e-3),
        "DoRA lora should change the output",
    )
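# Illustrative sketch of what `get_dummy_components(use_dora=True)` presumably sets
# up: a PEFT `LoraConfig` with DoRA enabled. Not part of the test suite; the rank,
# alpha, and target module names are assumptions that depend on the model.
from peft import LoraConfig

dora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
    use_dora=True,  # weight-decomposed LoRA; needs peft>=0.9.0, matching the decorator above
)
# A diffusers model exposes `add_adapter` through its PEFT integration:
# pipe.transformer.add_adapter(dora_config)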
def test_missing_keys_warning(self):
    # Skip text encoder check for now as that is handled with `transformers`.
    components, _, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
...@@ -1805,9 +1742,8 @@ class PeftLoraLoaderMixinTests:
    self.assertTrue(missing_key.replace(f"{component}.", "") in cap_logger.out.replace("default_0.", ""))
def test_unexpected_keys_warning(self):
    # Skip text encoder check for now as that is handled with `transformers`.
    components, _, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
...@@ -1842,23 +1778,22 @@ class PeftLoraLoaderMixinTests:
    Tests a simple inference with LoRA attached to the text encoder and UNet, compiles
    the adapted models, and makes sure inference works as expected.
    """
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)

    if self.has_two_text_encoders or self.has_three_text_encoders:
        pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)

    # Just makes sure it works.
    _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
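# Illustrative sketch of compiling LoRA-adapted modules outside the test harness.
# Not part of the test suite; the checkpoint and LoRA repo are assumptions.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
image = pipe("a cat", num_inference_steps=4).images[0]  # first call pays the compilation cost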
def test_modify_padding_mode(self):
    def set_pad_mode(network, mode="circular"):
...@@ -1866,22 +1801,20 @@ class PeftLoraLoaderMixinTests:
            if isinstance(module, torch.nn.Conv2d):
                module.padding_mode = mode

    components, _, _ = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    _pad_mode = "circular"
    set_pad_mode(pipe.vae, _pad_mode)
    set_pad_mode(pipe.unet, _pad_mode)

    _, _, inputs = self.get_dummy_inputs()
    _ = pipe(**inputs)[0]
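# Illustrative sketch of the same padding swap as a standalone helper. Switching
# every Conv2d to circular padding is a known trick for seamlessly tileable images;
# this mirrors `set_pad_mode` above and is not part of the test suite.
import torch

def set_pad_mode(network: torch.nn.Module, mode: str = "circular") -> None:
    # Switch the padding mode on every conv layer in place.
    for module in network.modules():
        if isinstance(module, torch.nn.Conv2d):
            module.padding_mode = mode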
def test_logs_info_when_no_lora_keys_found(self):
    # Skip text encoder check for now as that is handled with `transformers`.
    components, _, _ = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
...@@ -1925,73 +1858,71 @@ class PeftLoraLoaderMixinTests:
def test_set_adapters_match_attention_kwargs(self):
    """Test to check if outputs after `set_adapters()` and attention kwargs match."""
    attention_kwargs_name = determine_attention_kwargs_name(self.pipeline_class)
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
    self.assertTrue(output_no_lora.shape == self.output_shape)

    pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)

    lora_scale = 0.5
    attention_kwargs = {attention_kwargs_name: {"scale": lora_scale}}
    output_lora_scale = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
    self.assertFalse(
        np.allclose(output_no_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
        "Lora + scale should change the output",
    )

    pipe.set_adapters("default", lora_scale)
    output_lora_scale_wo_kwargs = pipe(**inputs, generator=torch.manual_seed(0))[0]
    self.assertTrue(
        not np.allclose(output_no_lora, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
        "Lora + scale should change the output",
    )
    self.assertTrue(
        np.allclose(output_lora_scale, output_lora_scale_wo_kwargs, atol=1e-3, rtol=1e-3),
        "Lora + scale should match the output of `set_adapters()`.",
    )

    with tempfile.TemporaryDirectory() as tmpdirname:
        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
        self.pipeline_class.save_lora_weights(
            save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
        )

        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))

        for module_name, module in modules_to_save.items():
            self.assertTrue(check_if_lora_correctly_set(module), f"Lora not correctly set in {module_name}")

        output_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0), **attention_kwargs)[0]
        self.assertTrue(
            not np.allclose(output_no_lora, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
            "Lora + scale should change the output",
        )
        self.assertTrue(
            np.allclose(output_lora_scale, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give same results as attention_kwargs.",
        )
        self.assertTrue(
            np.allclose(output_lora_scale_wo_kwargs, output_lora_from_pretrained, atol=1e-3, rtol=1e-3),
            "Loading from saved checkpoints should give same results as set_adapters().",
        )
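# Illustrative sketch of the adapter-scaling API the test compares against attention
# kwargs. Not part of the test suite; repo and adapter names are assumptions.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5", adapter_name="lcm")
pipe.set_adapters("lcm", 0.5)  # persistent scale, no per-call kwargs needed
image = pipe("a cat", generator=torch.manual_seed(0)).images[0]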
@require_peft_version_greater("0.13.2") @require_peft_version_greater("0.13.2")
def test_lora_B_bias(self): def test_lora_B_bias(self):
# Currently, this test is only relevant for Flux Control LoRA as we are not # Currently, this test is only relevant for Flux Control LoRA as we are not
# aware of any other LoRA checkpoint that has its `lora_B` biases trained. # aware of any other LoRA checkpoint that has its `lora_B` biases trained.
components, _, denoiser_lora_config = self.get_dummy_components(self.scheduler_classes[0]) components, _, denoiser_lora_config = self.get_dummy_components()
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device) pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -2028,7 +1959,7 @@ class PeftLoraLoaderMixinTests: ...@@ -2028,7 +1959,7 @@ class PeftLoraLoaderMixinTests:
self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
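# Illustrative sketch of a LoRA config with trained `lora_B` biases, the case this
# test targets. Not part of the test suite; rank and target modules are assumptions.
from peft import LoraConfig

bias_config = LoraConfig(
    r=4,
    target_modules=["to_q", "to_k", "to_v"],
    lora_bias=True,  # adds a trainable bias to lora_B; needs peft>=0.13.2, per the decorator
)
# pipe.transformer.add_adapter(bias_config)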
def test_correct_lora_configs_with_different_ranks(self):
    components, _, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
...@@ -2114,7 +2045,7 @@ class PeftLoraLoaderMixinTests:
                self.assertEqual(submodule.bias.dtype, dtype_to_check)
    def initialize_pipeline(storage_dtype=None, compute_dtype=torch.float32):
        components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device, dtype=compute_dtype)
        pipe.set_progress_bar_config(disable=None)
...@@ -2181,7 +2112,7 @@ class PeftLoraLoaderMixinTests:
            self.assertTrue(module._diffusers_hook.get_hook(_PEFT_AUTOCAST_DISABLE_HOOK) is not None)

    # 1. Test forward with add_adapter
    components, _, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device, dtype=compute_dtype)
    pipe.set_progress_bar_config(disable=None)
...@@ -2211,7 +2142,7 @@ class PeftLoraLoaderMixinTests:
        )
        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

        components, _, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device, dtype=compute_dtype)
        pipe.set_progress_bar_config(disable=None)
...@@ -2231,10 +2162,7 @@ class PeftLoraLoaderMixinTests:
@parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
    pipe = self.pipeline_class(**components)

    pipe, _ = self.add_adapters_to_pipeline(
...@@ -2280,10 +2208,7 @@ class PeftLoraLoaderMixinTests:
@parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components(lora_alpha=lora_alpha)
    pipe = self.pipeline_class(**components).to(torch_device)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)
...@@ -2311,8 +2236,7 @@ class PeftLoraLoaderMixinTests:
def test_lora_unload_add_adapter(self):
    """Tests if `unload_lora_weights()` -> `add_adapter()` works."""
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components).to(torch_device)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)
...@@ -2330,51 +2254,48 @@ class PeftLoraLoaderMixinTests:
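# Illustrative sketch of the unload-then-retrain entry point this test covers.
# Not part of the test suite; the checkpoint, LoRA repo, and config values are
# assumptions.
from diffusers import StableDiffusionPipeline
from peft import LoraConfig

pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
pipe.unload_lora_weights()  # back to the base weights, LoRA modules removed
pipe.unet.add_adapter(LoraConfig(r=4, target_modules=["to_q", "to_v"]))  # fresh, untrained adapter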
def test_inference_load_delete_load_adapters(self):
    "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
    components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
    _, _, inputs = self.get_dummy_inputs(with_generator=False)

    output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]

    if "text_encoder" in self.pipeline_class._lora_loadable_modules:
        pipe.text_encoder.add_adapter(text_lora_config)
        self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")

    denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
    denoiser.add_adapter(denoiser_lora_config)
    self.assertTrue(check_if_lora_correctly_set(denoiser), "Lora not correctly set in denoiser.")

    if self.has_two_text_encoders or self.has_three_text_encoders:
        lora_loadable_components = self.pipeline_class._lora_loadable_modules
        if "text_encoder_2" in lora_loadable_components:
            pipe.text_encoder_2.add_adapter(text_lora_config)
            self.assertTrue(
                check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
            )

    output_adapter_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]

    with tempfile.TemporaryDirectory() as tmpdirname:
        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
        lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
        self.pipeline_class.save_lora_weights(save_directory=tmpdirname, **lora_state_dicts)
        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

        # First, delete adapter and compare.
        pipe.delete_adapters(pipe.get_active_adapters()[0])
        output_no_adapter = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertFalse(np.allclose(output_adapter_1, output_no_adapter, atol=1e-3, rtol=1e-3))
        self.assertTrue(np.allclose(output_no_lora, output_no_adapter, atol=1e-3, rtol=1e-3))

        # Then load adapter and compare.
        pipe.load_lora_weights(tmpdirname)
        output_lora_loaded = pipe(**inputs, generator=torch.manual_seed(0))[0]
        self.assertTrue(np.allclose(output_adapter_1, output_lora_loaded, atol=1e-3, rtol=1e-3))
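# Illustrative sketch of the delete-then-reload cycle from the user side. Not part
# of the test suite; repo and adapter names are assumptions.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5")
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5", adapter_name="lcm")
pipe.delete_adapters("lcm")  # outputs match the base model again
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5", adapter_name="lcm")  # and the LoRA is back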
def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
    from diffusers.hooks.group_offloading import _get_top_level_group_offload_hook
...@@ -2382,7 +2303,7 @@ class PeftLoraLoaderMixinTests:
    onload_device = torch_device
    offload_device = torch.device("cpu")

    components, text_lora_config, denoiser_lora_config = self.get_dummy_components()
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
    pipe.set_progress_bar_config(disable=None)
...@@ -2399,7 +2320,7 @@ class PeftLoraLoaderMixinTests:
        )
        self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

        components, _, _ = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
...@@ -2451,7 +2372,7 @@ class PeftLoraLoaderMixinTests:

@require_torch_accelerator
def test_lora_loading_model_cpu_offload(self):
    components, _, denoiser_lora_config = self.get_dummy_components()
    _, _, inputs = self.get_dummy_inputs(with_generator=False)
    pipe = self.pipeline_class(**components)
    pipe = pipe.to(torch_device)
...@@ -2470,7 +2391,7 @@ class PeftLoraLoaderMixinTests:
            save_directory=tmpdirname, safe_serialization=True, **lora_state_dicts
        )
        # reinitialize the pipeline to mimic the inference workflow.
        components, _, denoiser_lora_config = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.enable_model_cpu_offload(device=torch_device)
        pipe.load_lora_weights(tmpdirname)
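# Illustrative sketch of loading LoRA weights into an offloaded pipeline, the
# workflow the test above mimics. Not part of the test suite; checkpoint names
# are assumptions.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # each component moves to the accelerator only while it runs
pipe.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
image = pipe("a cat", num_inference_steps=4).images[0]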
...