Unverified Commit 62863bb1 authored by YiYi Xu, committed by GitHub

Revert "[LoRA] introduce LoraBaseMixin to promote reusability." (#8976)

Revert "[LoRA] introduce LoraBaseMixin to promote reusability. (#8774)"

This reverts commit 527430d0.
parent 1fd647f2
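For downstream code, the practical effect of this revert is that pipelines again expose LoRA loading through LoraLoaderMixin rather than StableDiffusionLoraLoaderMixin, as the hunks below show. A minimal sketch of the restored surface, assuming a standard Stable Diffusion checkpoint (the checkpoint id and LoRA file path are illustrative, not part of this commit):

from diffusers import StableDiffusionPipeline
from diffusers.loaders import LoraLoaderMixin  # name restored by this revert

# Pipelines mix LoraLoaderMixin back in, so the LoRA load/unload entry
# points live directly on the pipeline object again.
assert issubclass(StableDiffusionPipeline, LoraLoaderMixin)

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("pytorch_lora_weights.safetensors")  # illustrative local file
pipe.unload_lora_weights()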
@@ -11,7 +11,7 @@ from torch.nn.functional import grid_sample
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -281,9 +281,7 @@ def create_motion_field_and_warp_latents(motion_field_strength_x, motion_field_s
return warped_latents
class TextToVideoZeroPipeline(
DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionLoraLoaderMixin
):
class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin):
r"""
Pipeline for zero-shot text-to-video generation using Stable Diffusion.
@@ -833,7 +831,7 @@ class TextToVideoZeroPipeline(
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
@@ -966,7 +964,7 @@ class TextToVideoZeroPipeline(
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if self.text_encoder is not None:
if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
...
@@ -14,7 +14,7 @@ from transformers import (
)
from ...image_processor import VaeImageProcessor
from ...loaders import StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
@@ -422,7 +422,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
"""
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
@@ -555,7 +555,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if self.text_encoder is not None:
if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
...
@@ -20,7 +20,7 @@ import numpy as np
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from ...loaders import StableDiffusionLoraLoaderMixin
from ...loaders import LoraLoaderMixin
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import BaseOutput, deprecate, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
@@ -62,7 +62,7 @@ class WuerstchenPriorPipelineOutput(BaseOutput):
image_embeddings: Union[torch.Tensor, np.ndarray]
class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
"""
Pipeline for generating image prior for Wuerstchen.
@@ -70,8 +70,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin)
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
The pipeline also inherits the following loading methods:
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
- [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
Args:
prior ([`Prior`]):
@@ -95,7 +95,6 @@ class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin)
text_encoder_name = "text_encoder"
model_cpu_offload_seq = "text_encoder->prior"
_callback_tensor_inputs = ["latents", "text_encoder_hidden_states", "negative_prompt_embeds"]
_lora_loadable_modules = ["prior", "text_encoder"]
def __init__(
self,
...
@@ -12,55 +12,376 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import tempfile
import unittest
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
SD3Transformer2DModel,
StableDiffusion3Pipeline,
)
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
if is_peft_available():
pass
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
sys.path.append(".")
from utils import PeftLoraLoaderMixinTests # noqa: E402
from utils import check_if_lora_correctly_set # noqa: E402
@require_peft_backend
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
class SD3LoRATests(unittest.TestCase):
pipeline_class = StableDiffusion3Pipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler()
scheduler_kwargs = {}
transformer_kwargs = {
"sample_size": 32,
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"attention_head_dim": 8,
"num_attention_heads": 4,
"caption_projection_dim": 32,
"joint_attention_dim": 32,
"pooled_projection_dim": 64,
"out_channels": 4,
}
vae_kwargs = {
"sample_size": 32,
"in_channels": 3,
"out_channels": 3,
"block_out_channels": (4,),
"layers_per_block": 1,
"latent_channels": 4,
"norm_num_groups": 1,
"use_quant_conv": False,
"use_post_quant_conv": False,
"shift_factor": 0.0609,
"scaling_factor": 1.5035,
}
has_three_text_encoders = True
def get_dummy_components(self):
torch.manual_seed(0)
transformer = SD3Transformer2DModel(
sample_size=32,
patch_size=1,
in_channels=4,
num_layers=1,
attention_head_dim=8,
num_attention_heads=4,
caption_projection_dim=32,
joint_attention_dim=32,
pooled_projection_dim=64,
out_channels=4,
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModelWithProjection(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = CLIPTextModelWithProjection(clip_text_encoder_config)
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=4,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"text_encoder_3": text_encoder_3,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"tokenizer_3": tokenizer_3,
"transformer": transformer,
"vae": vae,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "np",
}
return inputs
def get_lora_config_for_transformer(self):
lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
return lora_config
def get_lora_config_for_text_encoders(self):
text_lora_config = LoraConfig(
r=4,
lora_alpha=4,
init_lora_weights="gaussian",
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
return text_lora_config
def test_simple_inference_with_transformer_lora_save_load(self):
components = self.get_dummy_components()
transformer_config = self.get_lora_config_for_transformer()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.transformer.add_adapter(transformer_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
inputs = self.get_dummy_inputs(torch_device)
images_lora = pipe(**inputs).images
with tempfile.TemporaryDirectory() as tmpdirname:
transformer_state_dict = get_peft_model_state_dict(pipe.transformer)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
transformer_lora_layers=transformer_state_dict,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
inputs = self.get_dummy_inputs(torch_device)
images_lora_from_pretrained = pipe(**inputs).images
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
def test_simple_inference_with_clip_encoders_lora_save_load(self):
components = self.get_dummy_components()
transformer_config = self.get_lora_config_for_transformer()
text_encoder_config = self.get_lora_config_for_text_encoders()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
pipe.transformer.add_adapter(transformer_config)
pipe.text_encoder.add_adapter(text_encoder_config)
pipe.text_encoder_2.add_adapter(text_encoder_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder.")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2.")
inputs = self.get_dummy_inputs(torch_device)
images_lora = pipe(**inputs).images
with tempfile.TemporaryDirectory() as tmpdirname:
transformer_state_dict = get_peft_model_state_dict(pipe.transformer)
text_encoder_one_state_dict = get_peft_model_state_dict(pipe.text_encoder)
text_encoder_two_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
transformer_lora_layers=transformer_state_dict,
text_encoder_lora_layers=text_encoder_one_state_dict,
text_encoder_2_lora_layers=text_encoder_two_state_dict,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.unload_lora_weights()
pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))
inputs = self.get_dummy_inputs(torch_device)
images_lora_from_pretrained = pipe(**inputs).images
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text_encoder_one")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text_encoder_two")
self.assertTrue(
np.allclose(images_lora, images_lora_from_pretrained, atol=1e-3, rtol=1e-3),
"Loading from saved checkpoints should give same results.",
)
def test_simple_inference_with_transformer_lora_and_scale(self):
components = self.get_dummy_components()
transformer_lora_config = self.get_lora_config_for_transformer()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_no_lora = pipe(**inputs).images
pipe.transformer.add_adapter(transformer_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
inputs = self.get_dummy_inputs(torch_device)
output_lora = pipe(**inputs).images
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
inputs = self.get_dummy_inputs(torch_device)
output_lora_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.5}).images
self.assertTrue(
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
inputs = self.get_dummy_inputs(torch_device)
output_lora_0_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.0}).images
self.assertTrue(
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
"Lora + 0 scale should lead to same result as no LoRA",
)
def test_simple_inference_with_clip_encoders_lora_and_scale(self):
components = self.get_dummy_components()
transformer_lora_config = self.get_lora_config_for_transformer()
text_encoder_config = self.get_lora_config_for_text_encoders()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_no_lora = pipe(**inputs).images
pipe.transformer.add_adapter(transformer_lora_config)
pipe.text_encoder.add_adapter(text_encoder_config)
pipe.text_encoder_2.add_adapter(text_encoder_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text_encoder_one")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text_encoder_two")
inputs = self.get_dummy_inputs(torch_device)
output_lora = pipe(**inputs).images
self.assertTrue(
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
inputs = self.get_dummy_inputs(torch_device)
output_lora_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.5}).images
self.assertTrue(
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
inputs = self.get_dummy_inputs(torch_device)
output_lora_0_scale = pipe(**inputs, joint_attention_kwargs={"scale": 0.0}).images
self.assertTrue(
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
"Lora + 0 scale should lead to same result as no LoRA",
)
def test_simple_inference_with_transformer_fused(self):
components = self.get_dummy_components()
transformer_lora_config = self.get_lora_config_for_transformer()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_no_lora = pipe(**inputs).images
pipe.transformer.add_adapter(transformer_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
inputs = self.get_dummy_inputs(torch_device)
output_fused = pipe(**inputs).images
self.assertFalse(
np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
)
def test_simple_inference_with_transformer_fused_with_no_fusion(self):
components = self.get_dummy_components()
transformer_lora_config = self.get_lora_config_for_transformer()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_no_lora = pipe(**inputs).images
pipe.transformer.add_adapter(transformer_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
inputs = self.get_dummy_inputs(torch_device)
output_lora = pipe(**inputs).images
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
inputs = self.get_dummy_inputs(torch_device)
output_fused = pipe(**inputs).images
self.assertFalse(
np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
)
self.assertTrue(
np.allclose(output_fused, output_lora, atol=1e-3, rtol=1e-3),
"Fused LoRA output should match the output with LoRA applied but not fused.",
)
def test_simple_inference_with_transformer_fuse_unfuse(self):
components = self.get_dummy_components()
transformer_lora_config = self.get_lora_config_for_transformer()
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(torch_device)
output_no_lora = pipe(**inputs).images
pipe.transformer.add_adapter(transformer_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
inputs = self.get_dummy_inputs(torch_device)
output_fused = pipe(**inputs).images
self.assertFalse(
np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
)
pipe.unfuse_lora()
self.assertTrue(check_if_lora_correctly_set(pipe.transformer), "Lora not correctly set in transformer")
inputs = self.get_dummy_inputs(torch_device)
output_unfused_lora = pipe(**inputs).images
self.assertTrue(
np.allclose(output_fused, output_unfused_lora, atol=1e-3, rtol=1e-3), "Fused and unfused LoRA outputs should match"
)
@require_torch_gpu
def test_sd3_lora(self):
...
@@ -19,14 +19,12 @@ from itertools import product
import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import (
AutoencoderKL,
DDIMScheduler,
FlowMatchEulerDiscreteScheduler,
LCMScheduler,
SD3Transformer2DModel,
UNet2DConditionModel,
)
from diffusers.utils.import_utils import is_peft_available
@@ -73,47 +71,28 @@ class PeftLoraLoaderMixinTests:
scheduler_cls = None
scheduler_kwargs = None
has_two_text_encoders = False
has_three_text_encoders = False
unet_kwargs = None
transformer_kwargs = None
vae_kwargs = None
def get_dummy_components(self, scheduler_cls=None, use_dora=False):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders:
raise ValueError("Both `has_two_text_encoders` and `has_three_text_encoders` cannot be True.")
scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
rank = 4
torch.manual_seed(0)
if self.unet_kwargs is not None:
unet = UNet2DConditionModel(**self.unet_kwargs)
else:
transformer = SD3Transformer2DModel(**self.transformer_kwargs)
unet = UNet2DConditionModel(**self.unet_kwargs)
scheduler = scheduler_cls(**self.scheduler_kwargs)
torch.manual_seed(0)
vae = AutoencoderKL(**self.vae_kwargs)
if not self.has_three_text_encoders:
text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
text_encoder = CLIPTextModel.from_pretrained("peft-internal-testing/tiny-clip-text-2")
tokenizer = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
if self.has_two_text_encoders:
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("peft-internal-testing/tiny-clip-text-2")
tokenizer_2 = CLIPTokenizer.from_pretrained("peft-internal-testing/tiny-clip-text-2")
if self.has_three_text_encoders:
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_3 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
text_encoder = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained("hf-internal-testing/tiny-sd3-text_encoder-2")
text_encoder_3 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
text_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
@@ -122,7 +101,7 @@ class PeftLoraLoaderMixinTests:
use_dora=use_dora,
)
denoiser_lora_config = LoraConfig(
unet_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
@@ -130,31 +109,18 @@ class PeftLoraLoaderMixinTests:
use_dora=use_dora,
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.unet_kwargs is not None:
pipeline_components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": None,
"feature_extractor": None,
}
elif self.has_three_text_encoders and self.transformer_kwargs is not None:
pipeline_components = {
"transformer": transformer,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"text_encoder_3": text_encoder_3,
"tokenizer_3": tokenizer_3,
}
if self.has_two_text_encoders:
pipeline_components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"text_encoder_2": text_encoder_2,
"tokenizer_2": tokenizer_2,
"image_encoder": None,
"feature_extractor": None,
}
else:
pipeline_components = {
"unet": unet,
@@ -167,7 +133,7 @@ class PeftLoraLoaderMixinTests:
"image_encoder": None,
}
return pipeline_components, text_lora_config, denoiser_lora_config
return pipeline_components, text_lora_config, unet_lora_config
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
@@ -204,12 +170,7 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple inference and makes sure it works as expected
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -217,20 +178,14 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs()
output_no_lora = pipe(**inputs).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
def test_simple_inference_with_text_lora(self):
"""
Tests a simple inference with lora attached to the text encoder
and makes sure it works as expected
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -238,13 +193,12 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -260,12 +214,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to the text encoder + scale argument
and makes sure it works as expected
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -273,13 +222,12 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -290,27 +238,17 @@ class PeftLoraLoaderMixinTests:
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
if self.unet_kwargs is not None:
output_lora_scale = pipe(
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
).images
else:
output_lora_scale = pipe(
**inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5}
).images
output_lora_scale = pipe(
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
).images
self.assertTrue(
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
if self.unet_kwargs is not None:
output_lora_0_scale = pipe(
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
).images
else:
output_lora_0_scale = pipe(
**inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0}
).images
output_lora_0_scale = pipe(
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
).images
self.assertTrue(
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
"Lora + 0 scale should lead to same result as no LoRA",
@@ -321,12 +259,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to the text encoder + fuses the lora weights into the base model
and makes sure it works as expected
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -334,13 +267,12 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -350,7 +282,7 @@ class PeftLoraLoaderMixinTests:
# Fusing should still keep the LoRA layers
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
@@ -365,12 +297,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference with lora attached to text encoder, then unloads the lora weights
and makes sure it works as expected
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -378,13 +305,12 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -396,7 +322,7 @@ class PeftLoraLoaderMixinTests:
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
self.assertFalse(
check_if_lora_correctly_set(pipe.text_encoder_2),
"Lora not correctly unloaded in text encoder 2",
@@ -412,12 +338,7 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple use case where users could use saving utilities for LoRA.
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -425,13 +346,12 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -441,7 +361,7 @@ class PeftLoraLoaderMixinTests:
with tempfile.TemporaryDirectory() as tmpdirname:
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
self.pipeline_class.save_lora_weights(
@@ -465,7 +385,7 @@ class PeftLoraLoaderMixinTests:
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
@@ -481,14 +401,9 @@ class PeftLoraLoaderMixinTests:
with different ranks and some adapters removed
and makes sure it works as expected
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, _ = self.get_dummy_components(scheduler_cls)
# Verify `StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
# Verify `LoraLoaderMixin.load_lora_into_text_encoder` handles different ranks per module (PR#8324).
text_lora_config = LoraConfig(
r=4,
rank_pattern={"q_proj": 1, "k_proj": 2, "v_proj": 3},
@@ -503,8 +418,7 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -516,7 +430,7 @@ class PeftLoraLoaderMixinTests:
if "text_model.encoder.layers.4" not in module_name
}
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -548,12 +462,7 @@ class PeftLoraLoaderMixinTests:
"""
Tests a simple use case where users could use saving utilities for LoRA through save_pretrained
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
@@ -561,13 +470,12 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -586,7 +494,7 @@ class PeftLoraLoaderMixinTests:
"Lora not correctly set in text encoder",
)
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
self.assertTrue(
check_if_lora_correctly_set(pipe_from_pretrained.text_encoder_2),
"Lora not correctly set in text encoder 2",
@@ -599,42 +507,27 @@ class PeftLoraLoaderMixinTests:
"Loading from saved checkpoints should give same results.",
)
def test_simple_inference_with_text_denoiser_lora_save_load(self):
def test_simple_inference_with_text_unet_lora_save_load(self):
"""
Tests a simple use case where users could use saving utilities for LoRA for Unet + text encoder
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config)
else:
pipe.transformer.add_adapter(denoiser_lora_config)
pipe.unet.add_adapter(unet_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in Unet")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -644,36 +537,22 @@ class PeftLoraLoaderMixinTests:
with tempfile.TemporaryDirectory() as tmpdirname:
text_encoder_state_dict = get_peft_model_state_dict(pipe.text_encoder)
if self.unet_kwargs is not None:
denoiser_state_dict = get_peft_model_state_dict(pipe.unet)
else:
denoiser_state_dict = get_peft_model_state_dict(pipe.transformer)
if self.has_two_text_encoders or self.has_three_text_encoders:
unet_state_dict = get_peft_model_state_dict(pipe.unet)
if self.has_two_text_encoders:
text_encoder_2_state_dict = get_peft_model_state_dict(pipe.text_encoder_2)
if self.unet_kwargs is not None:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
unet_lora_layers=denoiser_state_dict,
safe_serialization=False,
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
transformer_lora_layers=denoiser_state_dict,
safe_serialization=False,
)
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
text_encoder_2_lora_layers=text_encoder_2_state_dict,
unet_lora_layers=unet_state_dict,
safe_serialization=False,
)
else:
self.pipeline_class.save_lora_weights(
save_directory=tmpdirname,
text_encoder_lora_layers=text_encoder_state_dict,
unet_lora_layers=denoiser_state_dict,
unet_lora_layers=unet_state_dict,
safe_serialization=False,
)
@@ -684,10 +563,9 @@ class PeftLoraLoaderMixinTests:
images_lora_from_pretrained = pipe(**inputs, generator=torch.manual_seed(0)).images
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
@@ -697,37 +575,27 @@ class PeftLoraLoaderMixinTests:
"Loading from saved checkpoints should give same results.",
)
def test_simple_inference_with_text_denoiser_lora_and_scale(self):
def test_simple_inference_with_text_unet_lora_and_scale(self):
"""
Tests a simple inference with lora attached to the text encoder + Unet + scale argument
and makes sure it works as expected
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config)
else:
pipe.transformer.add_adapter(denoiser_lora_config)
pipe.unet.add_adapter(unet_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -738,27 +606,17 @@ class PeftLoraLoaderMixinTests:
not np.allclose(output_lora, output_no_lora, atol=1e-3, rtol=1e-3), "Lora should change the output"
)
if self.unet_kwargs is not None:
output_lora_scale = pipe(
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
).images
else:
output_lora_scale = pipe(
**inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.5}
).images
output_lora_scale = pipe(
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.5}
).images
self.assertTrue(
not np.allclose(output_lora, output_lora_scale, atol=1e-3, rtol=1e-3),
"Lora + scale should change the output",
)
if self.unet_kwargs is not None:
output_lora_0_scale = pipe(
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
).images
else:
output_lora_0_scale = pipe(
**inputs, generator=torch.manual_seed(0), joint_attention_kwargs={"scale": 0.0}
).images
output_lora_0_scale = pipe(
**inputs, generator=torch.manual_seed(0), cross_attention_kwargs={"scale": 0.0}
).images
self.assertTrue(
np.allclose(output_no_lora, output_lora_0_scale, atol=1e-3, rtol=1e-3),
"Lora + 0 scale should lead to same result as no LoRA",
@@ -769,38 +627,28 @@ class PeftLoraLoaderMixinTests:
"The scaling parameter has not been correctly restored!",
)
def test_simple_inference_with_text_lora_denoiser_fused(self):
def test_simple_inference_with_text_lora_unet_fused(self):
"""
Tests a simple inference with lora attached to the text encoder + fuses the lora weights into the base model
and makes sure it works as expected - with the UNet
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config)
else:
pipe.transformer.add_adapter(denoiser_lora_config)
pipe.unet.add_adapter(unet_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -809,10 +657,9 @@ class PeftLoraLoaderMixinTests:
pipe.fuse_lora()
# Fusing should still keep the LoRA layers
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
)
@@ -822,37 +669,27 @@ class PeftLoraLoaderMixinTests:
np.allclose(output_fused, output_no_lora, atol=1e-3, rtol=1e-3), "Fused lora should change the output"
)
def test_simple_inference_with_text_denoiser_lora_unloaded(self):
def test_simple_inference_with_text_unet_lora_unloaded(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config)
else:
pipe.transformer.add_adapter(denoiser_lora_config)
pipe.unet.add_adapter(unet_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
@@ -863,12 +700,9 @@ class PeftLoraLoaderMixinTests:
self.assertFalse(
check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly unloaded in text encoder"
)
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertFalse(
check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly unloaded in denoiser"
)
self.assertFalse(check_if_lora_correctly_set(pipe.unet), "Lora not correctly unloaded in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
self.assertFalse(
check_if_lora_correctly_set(pipe.text_encoder_2),
"Lora not correctly unloaded in text encoder 2",
@@ -880,34 +714,25 @@ class PeftLoraLoaderMixinTests:
"Fused lora should change the output",
)
def test_simple_inference_with_text_denoiser_lora_unfused(self):
def test_simple_inference_with_text_unet_lora_unfused(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then fuses and unfuses the lora weights
and makes sure it works as expected
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config)
else:
pipe.transformer.add_adapter(denoiser_lora_config)
pipe.unet.add_adapter(unet_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
......@@ -922,10 +747,9 @@ class PeftLoraLoaderMixinTests:
output_unfused_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
# unfusing should still keep the LoRA layers
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Unfuse should still keep LoRA layers")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Unfuse should still keep LoRA layers")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Unfuse should still keep LoRA layers")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Unfuse should still keep LoRA layers"
)
@@ -936,18 +760,13 @@ class PeftLoraLoaderMixinTests:
"Fused lora should change the output",
)
def test_simple_inference_with_text_denoiser_multi_adapter(self):
def test_simple_inference_with_text_unet_multi_adapter(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and sets them
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -958,20 +777,13 @@ class PeftLoraLoaderMixinTests:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
@@ -1014,21 +826,13 @@ class PeftLoraLoaderMixinTests:
"output with no lora and output with lora disabled should give same results",
)
def test_simple_inference_with_text_denoiser_block_scale(self):
def test_simple_inference_with_text_unet_block_scale(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
one adapter and sets different weights for different blocks (i.e. block lora)
"""
if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
return
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
@@ -1037,16 +841,12 @@ class PeftLoraLoaderMixinTests:
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
......@@ -1081,21 +881,13 @@ class PeftLoraLoaderMixinTests:
"output with no lora and output with lora disabled should give same results",
)
def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
def test_simple_inference_with_text_unet_multi_adapter_block_lora(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and sets different weights for different blocks (i.e., block lora)
"""
if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
return
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -1106,20 +898,13 @@ class PeftLoraLoaderMixinTests:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
......@@ -1168,10 +953,8 @@ class PeftLoraLoaderMixinTests:
with self.assertRaises(ValueError):
pipe.set_adapters(["adapter-1", "adapter-2"], [scales_1])
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
def test_simple_inference_with_text_unet_block_scale_for_all_dict_options(self):
"""Tests that any valid combination of lora block scales can be used in pipe.set_adapter"""
if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
return
def updown_options(blocks_with_tf, layers_per_block, value):
"""
......@@ -1236,19 +1019,16 @@ class PeftLoraLoaderMixinTests:
return opts
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(self.scheduler_cls)
components, text_lora_config, unet_lora_config = self.get_dummy_components(self.scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
for scale_dict in all_possible_dict_opts(pipe.unet, value=1234):
......@@ -1258,18 +1038,13 @@ class PeftLoraLoaderMixinTests:
pipe.set_adapters("adapter-1", scale_dict) # test will fail if this line throws an error
def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and sets/deletes them
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -1280,20 +1055,13 @@ class PeftLoraLoaderMixinTests:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
......@@ -1345,14 +1113,8 @@ class PeftLoraLoaderMixinTests:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
pipe.set_adapters(["adapter-1", "adapter-2"])
pipe.delete_adapters(["adapter-1", "adapter-2"])
......@@ -1364,18 +1126,13 @@ class PeftLoraLoaderMixinTests:
"output with no lora and output with lora disabled should give same results",
)
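A sketch of the delete flow this test exercises; delete_adapters accepts a single name or a list:

    pipe.set_adapters(["adapter-1", "adapter-2"])
    pipe.delete_adapters("adapter-1")      # drop one adapter entirely
    pipe.delete_adapters(["adapter-2"])    # or several at once
    # with every adapter deleted, output should match the no-lora baseline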
def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
"""
Tests a simple inference with lora attached to text encoder and unet, attaches
multiple adapters and sets them with per-adapter weights
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -1386,20 +1143,13 @@ class PeftLoraLoaderMixinTests:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
......@@ -1452,13 +1202,8 @@ class PeftLoraLoaderMixinTests:
@skip_mps
def test_lora_fuse_nan(self):
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
......@@ -1466,23 +1211,16 @@ class PeftLoraLoaderMixinTests:
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
# corrupt one LoRA weight with `inf` values
with torch.no_grad():
if self.unet_kwargs:
pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A[
"adapter-1"
].weight += float("inf")
else:
pipe.transformer.transformer_blocks[0].attn.to_q.lora_A["adapter-1"].weight += float("inf")
pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_A["adapter-1"].weight += float(
"inf"
)
# with `safe_fusing=True` we should see an Error
with self.assertRaises(ValueError):
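The statement inside this assertRaises block is collapsed in the diff; in sketch form, the contract being tested (assuming safe_fusing defaults to False, as elsewhere in diffusers):

    pipe.fuse_lora(safe_fusing=True)    # inspects the fused weights first and raises ValueError on NaN/inf
    pipe.fuse_lora(safe_fusing=False)   # fuses without the check, so corrupted weights flow into inference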
......@@ -1500,32 +1238,21 @@ class PeftLoraLoaderMixinTests:
Tests a simple use case where we attach multiple adapters and check that the results
are as expected
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
adapter_names = pipe.get_active_adapters()
self.assertListEqual(adapter_names, ["adapter-1"])
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
adapter_names = pipe.get_active_adapters()
self.assertListEqual(adapter_names, ["adapter-2"])
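In sketch form, what get_active_adapters reports as adapters are added (reusing the test's unet_lora_config fixture; the commented return values follow the assertions above):

    pipe.unet.add_adapter(unet_lora_config, "adapter-1")
    print(pipe.get_active_adapters())   # ["adapter-1"]
    pipe.unet.add_adapter(unet_lora_config, "adapter-2")
    print(pipe.get_active_adapters())   # ["adapter-2"] -- the newest adapter becomes the active one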
......@@ -1538,108 +1265,65 @@ class PeftLoraLoaderMixinTests:
Tests a simple use case where we attach multiple adapters and check that the results
are as expected
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
adapter_names = pipe.get_list_adapters()
dicts_to_be_checked = {"text_encoder": ["adapter-1"]}
if self.unet_kwargs is not None:
dicts_to_be_checked.update({"unet": ["adapter-1"]})
else:
dicts_to_be_checked.update({"transformer": ["adapter-1"]})
self.assertDictEqual(adapter_names, dicts_to_be_checked)
self.assertDictEqual(adapter_names, {"text_encoder": ["adapter-1"], "unet": ["adapter-1"]})
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
adapter_names = pipe.get_list_adapters()
dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
if self.unet_kwargs is not None:
dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
else:
dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
self.assertDictEqual(adapter_names, dicts_to_be_checked)
self.assertDictEqual(
adapter_names, {"text_encoder": ["adapter-1", "adapter-2"], "unet": ["adapter-1", "adapter-2"]}
)
pipe.set_adapters(["adapter-1", "adapter-2"])
dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
if self.unet_kwargs is not None:
dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2"]})
else:
dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2"]})
self.assertDictEqual(
pipe.get_list_adapters(),
dicts_to_be_checked,
{"unet": ["adapter-1", "adapter-2"], "text_encoder": ["adapter-1", "adapter-2"]},
)
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-3")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-3")
dicts_to_be_checked = {"text_encoder": ["adapter-1", "adapter-2"]}
if self.unet_kwargs is not None:
dicts_to_be_checked.update({"unet": ["adapter-1", "adapter-2", "adapter-3"]})
else:
dicts_to_be_checked.update({"transformer": ["adapter-1", "adapter-2", "adapter-3"]})
self.assertDictEqual(pipe.get_list_adapters(), dicts_to_be_checked)
pipe.unet.add_adapter(unet_lora_config, "adapter-3")
self.assertDictEqual(
pipe.get_list_adapters(),
{"unet": ["adapter-1", "adapter-2", "adapter-3"], "text_encoder": ["adapter-1", "adapter-2"]},
)
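get_list_adapters, by contrast, reports every adapter attached to each component, active or not; per the assertions above, the expected shape is:

    print(pipe.get_list_adapters())
    # {"text_encoder": ["adapter-1", "adapter-2"],
    #  "unet": ["adapter-1", "adapter-2", "adapter-3"]}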
@require_peft_version_greater(peft_version="0.6.2")
def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
def test_simple_inference_with_text_lora_unet_fused_multi(self):
"""
Tests a simple inference with lora attached to the text encoder, fuses the lora weights into the base model,
and makes sure it works as expected in the unet, multi-adapter case
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-1")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-1")
# Attach a second adapter
pipe.text_encoder.add_adapter(text_lora_config, "adapter-2")
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config, "adapter-2")
else:
pipe.transformer.add_adapter(denoiser_lora_config, "adapter-2")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-1")
pipe.text_encoder_2.add_adapter(text_lora_config, "adapter-2")
self.assertTrue(
......@@ -1675,35 +1359,23 @@ class PeftLoraLoaderMixinTests:
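The fused-multi flow above reduces to the following sketch; it assumes fuse_lora accepts an adapter_names argument, which is how the test fuses a specific subset:

    pipe.set_adapters(["adapter-1", "adapter-2"])
    pipe.fuse_lora(adapter_names=["adapter-1", "adapter-2"])   # bake both adapters into the base weights
    images = pipe("a photo of a cat", num_inference_steps=20).images
    pipe.unfuse_lora()                                         # restore the original, unfused weights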
@require_peft_version_greater(peft_version="0.9.0")
def test_simple_inference_with_dora(self):
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, use_dora=True
)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls, use_dora=True)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
shape_to_be_checked = (1, 64, 64, 3) if self.unet_kwargs is not None else (1, 32, 32, 3)
self.assertTrue(output_no_dora_lora.shape == shape_to_be_checked)
self.assertTrue(output_no_dora_lora.shape == (1, 64, 64, 3))
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config)
else:
pipe.transformer.add_adapter(denoiser_lora_config)
pipe.unet.add_adapter(unet_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
......@@ -1717,34 +1389,25 @@ class PeftLoraLoaderMixinTests:
)
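The use_dora=True flag that get_dummy_components forwards ends up in peft's LoraConfig. A minimal sketch (the rank and target module names are illustrative, not taken from this diff):

    from peft import LoraConfig

    # DoRA (weight-decomposed LoRA) requires peft>=0.9.0, matching the decorator above.
    dora_config = LoraConfig(r=4, lora_alpha=4, use_dora=True, target_modules=["to_q", "to_k", "to_v"])
    pipe.unet.add_adapter(dora_config, "dora-adapter")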
@unittest.skip("This is failing for now - need to investigate")
def test_simple_inference_with_text_denoiser_lora_unfused_torch_compile(self):
def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
"""
Tests a simple inference with lora attached to text encoder and unet, then unloads the lora weights
and makes sure it works as expected
"""
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe.text_encoder.add_adapter(text_lora_config)
if self.unet_kwargs is not None:
pipe.unet.add_adapter(denoiser_lora_config)
else:
pipe.transformer.add_adapter(denoiser_lora_config)
pipe.unet.add_adapter(unet_lora_config)
self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
denoiser_to_checked = pipe.unet if self.unet_kwargs is not None else pipe.transformer
self.assertTrue(check_if_lora_correctly_set(denoiser_to_checked), "Lora not correctly set in denoiser")
self.assertTrue(check_if_lora_correctly_set(pipe.unet), "Lora not correctly set in Unet")
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2.add_adapter(text_lora_config)
self.assertTrue(
check_if_lora_correctly_set(pipe.text_encoder_2), "Lora not correctly set in text encoder 2"
......@@ -1753,27 +1416,19 @@ class PeftLoraLoaderMixinTests:
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
if self.has_two_text_encoders or self.has_three_text_encoders:
if self.has_two_text_encoders:
pipe.text_encoder_2 = torch.compile(pipe.text_encoder_2, mode="reduce-overhead", fullgraph=True)
# Just makes sure it works..
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
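Outside this (currently skipped) test, the same compile recipe applies to any pipeline; a sketch:

    pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    pipe.text_encoder = torch.compile(pipe.text_encoder, mode="reduce-overhead", fullgraph=True)
    _ = pipe("a photo of a cat", num_inference_steps=2).images   # the first call pays the compilation cost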
def test_modify_padding_mode(self):
if self.pipeline_class.__name__ == "StableDiffusion3Pipeline":
return
def set_pad_mode(network, mode="circular"):
for _, module in network.named_modules():
if isinstance(module, torch.nn.Conv2d):
module.padding_mode = mode
scheduler_classes = (
[FlowMatchEulerDiscreteScheduler]
if self.has_three_text_encoders and self.transformer_kwargs
else [DDIMScheduler, LCMScheduler]
)
for scheduler_cls in scheduler_classes:
for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components)
pipe = pipe.to(torch_device)
......
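The set_pad_mode helper above is self-contained and reusable; a sketch of switching a pipeline to circular padding, which is commonly used for tileable outputs (the prompt is a placeholder):

    set_pad_mode(pipe.unet, mode="circular")
    set_pad_mode(pipe.vae, mode="circular")
    images = pipe("seamless stone texture", num_inference_steps=20).images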