"docs/vscode:/vscode.git/clone" did not exist on "afc502bd669fe97001a50d0ce2373d13c011ed44"
Unverified Commit 61f6c547 authored by Sayak Paul, committed by GitHub

[LoRA] Remove the use of deprecated LoRA functionalities such as `LoRAAttnProcessor` (#6369)

* start deprecating loraattn.

* fix

* wrap into unet_lora_state_dict

* utilize text_encoder_lora_params

* utilize text_encoder_attn_modules

* debug

* debug

* remove print

* don't use text encoder for test_stable_diffusion_lora

* load the procs.

* set_default_attn_processor

* fix: set_default_attn_processor call.

* fix: lora_components[unet_lora_params]

* checking for 3d.

* 3d.

* more fixes.

* debug

* debug

* debug

* debug

* more debug

* more debug

* more debug

* more debug

* more debug

* more debug

* hack.

* remove comments and prep for a PR.

* appropriate set_lora_weights()

* fix

* fix: test_unload_lora_sd

* fix: test_unload_lora_sd

* use default attention processors.

* debug

* debug nan

* debug nan

* debug nan

* use NaN instead of inf

* remove comments.

* fix: test_text_encoder_lora_state_dict_unchanged

* attention processor default

* default attention processors.

* default

* style
parent 17546020
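For context: this change replaces the deprecated `LoRAAttnProcessor`-based test helpers with LoRA layers attached directly to the attention projections. Below is a minimal sketch of that pattern, assuming the non-PEFT code path these tests exercise (where `to_q`, `to_k`, `to_v`, and `to_out[0]` are `LoRACompatibleLinear` modules exposing `set_lora_layer`); the tiny UNet config shown is illustrative, not the exact one built by the tests.

# Sketch of the new helper pattern: attach LoRA layers to the attention
# projections, then collect a serializable state dict with `unet_lora_state_dict`.
from diffusers import UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer
from diffusers.training_utils import unet_lora_state_dict

unet = UNet2DConditionModel(  # illustrative tiny config
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)
rank = 4

for attn_processor_name in unet.attn_processors.keys():
    # Walk from the UNet root down to the attention module that owns this processor.
    attn_module = unet
    for n in attn_processor_name.split(".")[:-1]:
        attn_module = getattr(attn_module, n)
    # Attach a LoRA layer to each projection (q, k, v, out).
    for proj in (attn_module.to_q, attn_module.to_k, attn_module.to_v, attn_module.to_out[0]):
        proj.set_lora_layer(
            LoRALinearLayer(in_features=proj.in_features, out_features=proj.out_features, rank=rank)
        )

# Serializable LoRA weights, e.g. for `LoraLoaderMixin.save_lora_weights(unet_lora_layers=...)`.
lora_state_dict = unet_lora_state_dict(unet)

The text encoder side follows the same idea: `LoraLoaderMixin._modify_text_encoder` patches the projections, and the `text_encoder_lora_state_dict` helper added in this diff collects the resulting weights.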
@@ -22,7 +22,6 @@ import unittest
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub.repocard import RepoCard
from PIL import Image
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
@@ -41,17 +40,15 @@ from diffusers import (
UNet2DConditionModel,
UNet3DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.loaders import LoraLoaderMixin, StableDiffusionXLLoraLoaderMixin
from diffusers.models.attention_processor import (
Attention,
AttnProcessor,
AttnProcessor2_0,
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from diffusers.models.lora import PatchedLoraProjection, text_encoder_attn_modules
from diffusers.models.lora import LoRALinearLayer
from diffusers.training_utils import unet_lora_state_dict
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.testing_utils import (
deprecate_after_peft_backend,
@@ -64,118 +61,178 @@ from diffusers.utils.testing_utils import (
)
def create_lora_layers(model, mock_weights: bool = True):
lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
def text_encoder_attn_modules(text_encoder):
attn_modules = []
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
attn_modules.append((name, mod))
else:
raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
if mock_weights:
# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1
return attn_modules
return lora_attn_procs
def text_encoder_lora_state_dict(text_encoder):
state_dict = {}
def create_unet_lora_layers(unet: nn.Module):
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
return lora_attn_procs, unet_lora_layers
for name, module in text_encoder_attn_modules(text_encoder):
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
text_lora_attn_procs = {}
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
for name, module in text_encoder_attn_modules(text_encoder):
if isinstance(module.out_proj, nn.Linear):
out_features = module.out_proj.out_features
elif isinstance(module.out_proj, PatchedLoraProjection):
out_features = module.out_proj.regular_linear_layer.out_features
else:
assert False, module.out_proj.__class__
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
text_lora_attn_procs[name] = lora_attn_processor_class(hidden_size=out_features, cross_attention_dim=None)
return text_lora_attn_procs
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
return state_dict
def create_text_encoder_lora_layers(text_encoder: nn.Module):
text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
return text_encoder_lora_layers
def create_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True):
unet_lora_parameters = []
def create_lora_3d_layers(model, mock_weights: bool = True):
lora_attn_procs = {}
for name in model.attn_processors.keys():
has_cross_attention = name.endswith("attn2.processor") and not (
name.startswith("transformer_in") or "temp_attentions" in name.split(".")
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features,
out_features=attn_module.to_q.out_features,
rank=rank,
)
)
cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
elif name.startswith("transformer_in"):
# Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148
hidden_size = 8 * model.config.attention_head_dim
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features,
out_features=attn_module.to_k.out_features,
rank=rank,
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features,
out_features=attn_module.to_v.out_features,
rank=rank,
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=rank,
)
)
if mock_weights:
with torch.no_grad():
attn_module.to_q.lora_layer.up.weight += 1
attn_module.to_k.lora_layer.up.weight += 1
attn_module.to_v.lora_layer.up.weight += 1
attn_module.to_out[0].lora_layer.up.weight += 1
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
return unet_lora_parameters, unet_lora_state_dict(unet)
def create_3d_unet_lora_layers(unet: nn.Module, rank=4, mock_weights=True):
for attn_processor_name in unet.attn_processors.keys():
has_cross_attention = attn_processor_name.endswith("attn2.processor") and not (
attn_processor_name.startswith("transformer_in") or "temp_attentions" in attn_processor_name.split(".")
)
cross_attention_dim = unet.config.cross_attention_dim if has_cross_attention else None
if attn_processor_name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif attn_processor_name.startswith("up_blocks"):
block_id = int(attn_processor_name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif attn_processor_name.startswith("down_blocks"):
block_id = int(attn_processor_name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
elif attn_processor_name.startswith("transformer_in"):
# Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148
hidden_size = 8 * unet.config.attention_head_dim
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=min(attn_module.to_q.in_features, hidden_size),
out_features=attn_module.to_q.out_features
if cross_attention_dim is None
else max(attn_module.to_q.out_features, cross_attention_dim),
rank=rank,
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=min(attn_module.to_k.in_features, hidden_size),
out_features=attn_module.to_k.out_features
if cross_attention_dim is None
else max(attn_module.to_k.out_features, cross_attention_dim),
rank=rank,
)
)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=min(attn_module.to_v.in_features, hidden_size),
out_features=attn_module.to_v.out_features
if cross_attention_dim is None
else max(attn_module.to_v.out_features, cross_attention_dim),
rank=rank,
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=min(attn_module.to_out[0].in_features, hidden_size),
out_features=attn_module.to_out[0].out_features
if cross_attention_dim is None
else max(attn_module.to_out[0].out_features, cross_attention_dim),
rank=rank,
)
)
if mock_weights:
# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1
attn_module.to_q.lora_layer.up.weight += 1
attn_module.to_k.lora_layer.up.weight += 1
attn_module.to_v.lora_layer.up.weight += 1
attn_module.to_out[0].lora_layer.up.weight += 1
return lora_attn_procs
return unet_lora_state_dict(unet)
def set_lora_weights(lora_attn_parameters, randn_weight=False, var=1.0):
with torch.no_grad():
for parameter in lora_attn_parameters:
if randn_weight:
parameter[:] = torch.randn_like(parameter) * var
else:
torch.zero_(parameter)
if not isinstance(lora_attn_parameters, dict):
with torch.no_grad():
for parameter in lora_attn_parameters:
if randn_weight:
parameter[:] = torch.randn_like(parameter) * var
else:
torch.zero_(parameter)
else:
if randn_weight:
modified_state_dict = {k: torch.rand_like(v) * var for k, v in lora_attn_parameters.items()}
else:
modified_state_dict = {k: torch.zeros_like(v) * var for k, v in lora_attn_parameters.items()}
return modified_state_dict
def state_dicts_almost_equal(sd1, sd2):
@@ -192,6 +249,8 @@ def state_dicts_almost_equal(sd1, sd2):
@deprecate_after_peft_backend
class LoraLoaderMixinTests(unittest.TestCase):
lora_rank = 4
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
@@ -235,8 +294,13 @@ class LoraLoaderMixinTests(unittest.TestCase):
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
text_encoder_lora_layers = create_text_encoder_lora_layers(text_encoder)
unet_lora_raw_params, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank)
text_encoder_lora_params = LoraLoaderMixin._modify_text_encoder(
text_encoder, dtype=torch.float32, rank=self.lora_rank
)
text_encoder_lora_params = set_lora_weights(
text_encoder_lora_state_dict(text_encoder), randn_weight=True, var=0.1
)
pipeline_components = {
"unet": unet,
@@ -249,9 +313,9 @@ class LoraLoaderMixinTests(unittest.TestCase):
"image_encoder": None,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
"text_encoder_lora_layers": text_encoder_lora_layers,
"unet_lora_attn_procs": unet_lora_attn_procs,
"unet_lora_raw_params": unet_lora_raw_params,
"unet_lora_params": unet_lora_params,
"text_encoder_lora_params": text_encoder_lora_params,
}
return pipeline_components, lora_components
@@ -290,8 +354,8 @@ class LoraLoaderMixinTests(unittest.TestCase):
_, lora_components = self.get_dummy_components()
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
@@ -311,71 +375,12 @@ class LoraLoaderMixinTests(unittest.TestCase):
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
@unittest.skipIf(not torch.cuda.is_available(), reason="xformers requires cuda")
def test_stable_diffusion_attn_processors(self):
# disable_full_determinism()
device = "cuda" # ensure determinism for the device-dependent torch.Generator
components, _ = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs()
# run normal sd pipe
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
# run attention slicing
sd_pipe.enable_attention_slicing()
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
# run vae attention slicing
sd_pipe.enable_vae_slicing()
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
# run lora attention
attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()}
sd_pipe.unet.set_attn_processor(attn_processors)
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
@unittest.skipIf(not torch.cuda.is_available() or not is_xformers_available(), reason="xformers requires cuda")
def test_stable_diffusion_set_xformers_attn_processors(self):
# disable_full_determinism()
device = "cuda" # ensure determinism for the device-dependent torch.Generator
components, _ = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
_, _, inputs = self.get_dummy_inputs()
# run normal sd pipe
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
# run lora xformers attention
attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
attn_processors = {
k: LoRAXFormersAttnProcessor(hidden_size=v.hidden_size, cross_attention_dim=v.cross_attention_dim)
for k, v in attn_processors.items()
}
attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()}
sd_pipe.unet.set_attn_processor(attn_processors)
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
# enable_full_determinism()
def test_stable_diffusion_lora(self):
components, _ = self.get_dummy_components()
components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
sd_pipe.unet.set_default_attn_processor()
# forward 1
_, _, inputs = self.get_dummy_inputs()
@@ -385,9 +390,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
# set lora layers
lora_attn_procs = create_lora_layers(sd_pipe.unet)
sd_pipe.unet.set_attn_processor(lora_attn_procs)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.unet.load_attn_procs(lora_components["unet_lora_params"])
# forward 2
_, _, inputs = self.get_dummy_inputs()
@@ -420,8 +423,8 @@ class LoraLoaderMixinTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(tmpdirname)
@@ -434,7 +437,6 @@ class LoraLoaderMixinTests(unittest.TestCase):
def test_lora_save_load_no_safe_serialization(self):
pipeline_components, lora_components = self.get_dummy_components()
unet_lora_attn_procs = lora_components["unet_lora_attn_procs"]
sd_pipe = StableDiffusionPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
@@ -445,9 +447,13 @@ class LoraLoaderMixinTests(unittest.TestCase):
orig_image_slice = original_images[0, -3:, -3:, -1]
with tempfile.TemporaryDirectory() as tmpdirname:
unet = sd_pipe.unet
unet.set_attn_processor(unet_lora_attn_procs)
unet.save_attn_procs(tmpdirname, safe_serialization=False)
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
safe_serialization=False,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
sd_pipe.load_lora_weights(tmpdirname)
@@ -468,9 +474,18 @@ class LoraLoaderMixinTests(unittest.TestCase):
assert outputs_without_lora.shape == (1, 77, 32)
# monkey patch
params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
set_lora_weights(params, randn_weight=False)
text_encoder_lora_params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
text_encoder_lora_params = set_lora_weights(
text_encoder_lora_state_dict(pipe.text_encoder), randn_weight=False
)
with tempfile.TemporaryDirectory() as tmpdirname:
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=None,
text_encoder_lora_layers=text_encoder_lora_params,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.load_lora_weights(tmpdirname)
# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
@@ -480,13 +495,22 @@ class LoraLoaderMixinTests(unittest.TestCase):
outputs_without_lora, outputs_with_lora
), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs"
# create lora_attn_procs with randn up.weights
create_text_encoder_lora_attn_procs(pipe.text_encoder)
# monkey patch
params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
pipeline_components, _ = self.get_dummy_components()
pipe = StableDiffusionPipeline(**pipeline_components)
set_lora_weights(params, randn_weight=True)
text_encoder_lora_params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
text_encoder_lora_params = set_lora_weights(
text_encoder_lora_state_dict(pipe.text_encoder), randn_weight=True, var=0.1
)
with tempfile.TemporaryDirectory() as tmpdirname:
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=None,
text_encoder_lora_layers=text_encoder_lora_params,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.load_lora_weights(tmpdirname)
# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
@@ -508,8 +532,15 @@ class LoraLoaderMixinTests(unittest.TestCase):
# monkey patch
params = pipe._modify_text_encoder(pipe.text_encoder, pipe.lora_scale)
set_lora_weights(params, randn_weight=True)
params = set_lora_weights(text_encoder_lora_state_dict(pipe.text_encoder), var=0.1, randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=None,
text_encoder_lora_layers=params,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
pipe.load_lora_weights(tmpdirname)
# inference with lora
outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0]
@@ -541,8 +572,8 @@ class LoraLoaderMixinTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(tmpdirname)
@@ -587,19 +618,16 @@ class LoraLoaderMixinTests(unittest.TestCase):
pipeline_components, lora_components = self.get_dummy_components()
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
sd_pipe = StableDiffusionPipeline(**pipeline_components)
sd_pipe.unet.set_default_attn_processor()
original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
orig_image_slice = original_images[0, -3:, -3:, -1]
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_lora_layers"].parameters(), randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(tmpdirname)
@@ -677,7 +705,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
@@ -691,7 +719,9 @@ class LoraLoaderMixinTests(unittest.TestCase):
@deprecate_after_peft_backend
class SDXInpaintLoraMixinTests(unittest.TestCase):
class SDInpaintLoraMixinTests(unittest.TestCase):
lora_rank = 4
def get_dummy_inputs(self, device, seed=0, img_res=64, output_pil=True):
# TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
if output_pil:
@@ -765,6 +795,14 @@ class SDXInpaintLoraMixinTests(unittest.TestCase):
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
unet_lora_raw_params, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank)
text_encoder_lora_params = StableDiffusionXLLoraLoaderMixin._modify_text_encoder(
text_encoder, dtype=torch.float32, rank=self.lora_rank
)
text_encoder_lora_params = set_lora_weights(
text_encoder_lora_state_dict(text_encoder), randn_weight=True, var=0.1
)
components = {
"unet": unet,
"scheduler": scheduler,
@@ -775,15 +813,21 @@ class SDXInpaintLoraMixinTests(unittest.TestCase):
"feature_extractor": None,
"image_encoder": None,
}
return components
lora_components = {
"unet_lora_raw_params": unet_lora_raw_params,
"unet_lora_params": unet_lora_params,
"text_encoder_lora_params": text_encoder_lora_params,
}
return components, lora_components
def test_stable_diffusion_inpaint_lora(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionInpaintPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
sd_pipe.unet.set_default_attn_processor()
# forward 1
inputs = self.get_dummy_inputs(device)
@@ -792,9 +836,7 @@ class SDXInpaintLoraMixinTests(unittest.TestCase):
image_slice = image[0, -3:, -3:, -1]
# set lora layers
lora_attn_procs = create_lora_layers(sd_pipe.unet)
sd_pipe.unet.set_attn_processor(lora_attn_procs)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.unet.load_attn_procs(lora_components["unet_lora_params"])
# forward 2
inputs = self.get_dummy_inputs(device)
@@ -814,7 +856,9 @@ class SDXInpaintLoraMixinTests(unittest.TestCase):
@deprecate_after_peft_backend
class SDXLLoraLoaderMixinTests(unittest.TestCase):
def get_dummy_components(self):
lora_rank = 4
def get_dummy_components(self, modify_text_encoder=True):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
@@ -871,9 +915,24 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
text_encoder_one_lora_layers = create_text_encoder_lora_layers(text_encoder)
text_encoder_two_lora_layers = create_text_encoder_lora_layers(text_encoder_2)
_, unet_lora_params = create_unet_lora_layers(unet, rank=self.lora_rank)
if modify_text_encoder:
text_encoder_lora_params = StableDiffusionXLLoraLoaderMixin._modify_text_encoder(
text_encoder, dtype=torch.float32, rank=self.lora_rank
)
text_encoder_lora_params = set_lora_weights(
text_encoder_lora_state_dict(text_encoder), randn_weight=True, var=0.1
)
text_encoder_two_lora_params = StableDiffusionXLLoraLoaderMixin._modify_text_encoder(
text_encoder_2, dtype=torch.float32, rank=self.lora_rank
)
text_encoder_two_lora_params = set_lora_weights(
text_encoder_lora_state_dict(text_encoder_2), randn_weight=True, var=0.1
)
else:
text_encoder_lora_params = None
text_encoder_two_lora_params = None
pipeline_components = {
"unet": unet,
@@ -887,10 +946,9 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
"feature_extractor": None,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
"text_encoder_one_lora_layers": text_encoder_one_lora_layers,
"text_encoder_two_lora_layers": text_encoder_two_lora_layers,
"unet_lora_attn_procs": unet_lora_attn_procs,
"unet_lora_params": unet_lora_params,
"text_encoder_lora_params": text_encoder_lora_params,
"text_encoder_two_lora_params": text_encoder_two_lora_params,
}
return pipeline_components, lora_components
@@ -929,9 +987,9 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(tmpdirname)
@@ -946,21 +1004,17 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
pipeline_components, lora_components = self.get_dummy_components()
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
sd_pipe.unet.set_default_attn_processor()
original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
orig_image_slice = original_images[0, -3:, -3:, -1]
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(tmpdirname)
@@ -992,9 +1046,9 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
safe_serialization=False,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
@@ -1003,7 +1057,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
sd_pipe.unload_lora_weights()
def test_text_encoder_lora_state_dict_unchanged(self):
pipeline_components, lora_components = self.get_dummy_components()
pipeline_components, lora_components = self.get_dummy_components(modify_text_encoder=False)
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
text_encoder_1_sd_keys = sorted(sd_pipe.text_encoder.state_dict().keys())
@@ -1012,12 +1066,26 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
# Modify the text encoder.
_ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder(
sd_pipe.text_encoder, dtype=torch.float32, rank=self.lora_rank
)
lora_components["text_encoder_lora_params"] = set_lora_weights(
text_encoder_lora_state_dict(sd_pipe.text_encoder), randn_weight=True, var=0.1
)
_ = StableDiffusionXLLoraLoaderMixin._modify_text_encoder(
sd_pipe.text_encoder_2, dtype=torch.float32, rank=self.lora_rank
)
lora_components["text_encoder_two_lora_params"] = set_lora_weights(
text_encoder_lora_state_dict(sd_pipe.text_encoder_2), randn_weight=True, var=0.1
)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
safe_serialization=False,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
@@ -1050,9 +1118,9 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
@@ -1066,19 +1134,12 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
@@ -1087,7 +1148,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
# corrupt one LoRA weight with `NaN` values
with torch.no_grad():
sd_pipe.unet.mid_block.attentions[0].transformer_blocks[0].attn1.to_q.lora_layer.down.weight += float(
"inf"
"NaN"
)
# with `safe_fusing=True` we should see an Error
@@ -1112,17 +1173,12 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
orig_image_slice = original_images[0, -3:, -3:, -1]
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
@@ -1139,23 +1195,19 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
sd_pipe.unet.set_default_attn_processor()
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
orig_image_slice = original_images[0, -3:, -3:, -1]
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
@@ -1190,17 +1242,12 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
_ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
@@ -1229,17 +1276,12 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
_ = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
@@ -1255,9 +1297,9 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
@@ -1276,22 +1318,18 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
sd_pipe.unet.set_default_attn_processor()
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
original_imagee_slice = original_images[0, -3:, -3:, -1]
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
@@ -1323,23 +1361,19 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
sd_pipe.unet.set_default_attn_processor()
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
images_slice = images[0, -3:, -3:, -1]
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
@@ -1376,17 +1410,12 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True, var=0.1)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True, var=0.1)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True, var=0.1)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
unet_lora_layers=lora_components["unet_lora_params"],
text_encoder_lora_layers=lora_components["text_encoder_lora_params"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_params"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
@@ -1460,10 +1489,10 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
with torch.no_grad():
sample1 = model(**inputs_dict).sample
lora_attn_procs = create_lora_layers(model)
_, lora_params = create_unet_lora_layers(model)
# make sure we can set a list of attention processors
model.set_attn_processor(lora_attn_procs)
model.load_attn_procs(lora_params)
model.to(torch_device)
# test that attn processors can be set to itself
@@ -1480,120 +1509,6 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
# sample 2 and sample 3 should be different
assert (sample2 - sample3).abs().max() > 1e-4
def test_lora_save_load(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
old_sample = model(**inputs_dict).sample
lora_attn_procs = create_lora_layers(model)
model.set_attn_processor(lora_attn_procs)
with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname)
with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
assert (sample - new_sample).abs().max() < 5e-4
# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 5e-4
def test_lora_save_load_safetensors(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
old_sample = model(**inputs_dict).sample
lora_attn_procs = create_lora_layers(model)
model.set_attn_processor(lora_attn_procs)
with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=True)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname)
with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
assert (sample - new_sample).abs().max() < 1e-4
# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4
def test_lora_save_safetensors_load_torch(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = create_lora_layers(model, mock_weights=False)
model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=True)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors")
def test_lora_save_torch_force_load_safetensors_error(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = create_lora_layers(model, mock_weights=False)
model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
with self.assertRaises(IOError) as e:
new_model.load_attn_procs(tmpdirname, use_safetensors=True)
self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception))
def test_lora_on_off(self, expected_max_diff=1e-3):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1607,8 +1522,8 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
with torch.no_grad():
old_sample = model(**inputs_dict).sample
lora_attn_procs = create_lora_layers(model)
model.set_attn_processor(lora_attn_procs)
_, lora_params = create_unet_lora_layers(model)
model.load_attn_procs(lora_params)
with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
@@ -1637,8 +1552,8 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = create_lora_layers(model)
model.set_attn_processor(lora_attn_procs)
_, lora_params = create_unet_lora_layers(model)
model.load_attn_procs(lora_params)
# default
with torch.no_grad():
@@ -1712,10 +1627,10 @@ class UNet3DConditionModelTests(unittest.TestCase):
with torch.no_grad():
sample1 = model(**inputs_dict).sample
lora_attn_procs = create_lora_3d_layers(model)
unet_lora_params = create_3d_unet_lora_layers(model)
# make sure we can set a list of attention processors
model.set_attn_processor(lora_attn_procs)
model.load_attn_procs(unet_lora_params)
model.to(torch_device)
# test that attn processors can be set to itself
@@ -1732,172 +1647,6 @@ class UNet3DConditionModelTests(unittest.TestCase):
# sample 2 and sample 3 should be different
assert (sample2 - sample3).abs().max() > 3e-3
def test_lora_save_load(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = 8
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
old_sample = model(**inputs_dict).sample
lora_attn_procs = create_lora_3d_layers(model)
model.set_attn_processor(lora_attn_procs)
with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname)
with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
assert (sample - new_sample).abs().max() < 5e-3
# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4
def test_lora_save_load_safetensors(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = 8
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
old_sample = model(**inputs_dict).sample
lora_attn_procs = create_lora_3d_layers(model)
model.set_attn_processor(lora_attn_procs)
with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=True)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname)
with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
assert (sample - new_sample).abs().max() < 3e-3
# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4
def test_lora_save_safetensors_load_torch(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = 8
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = create_lora_3d_layers(model, mock_weights=False)
model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors")
def test_lora_save_torch_force_load_safetensors_error(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = 8
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = create_lora_3d_layers(model, mock_weights=False)
model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
with self.assertRaises(IOError) as e:
new_model.load_attn_procs(tmpdirname, use_safetensors=True)
self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception))
def test_lora_on_off(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = 8
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
old_sample = model(**inputs_dict).sample
lora_attn_procs = create_lora_3d_layers(model)
model.set_attn_processor(lora_attn_procs)
with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
model.set_default_attn_processor()
with torch.no_grad():
new_sample = model(**inputs_dict).sample
assert (sample - new_sample).abs().max() < 1e-4
assert (sample - old_sample).abs().max() < 3e-3
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_lora_xformers_on_off(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = 4
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
lora_attn_procs = create_lora_3d_layers(model)
model.set_attn_processor(lora_attn_procs)
# default
with torch.no_grad():
sample = model(**inputs_dict).sample
model.enable_xformers_memory_efficient_attention()
on_sample = model(**inputs_dict).sample
model.disable_xformers_memory_efficient_attention()
off_sample = model(**inputs_dict).sample
assert (sample - on_sample).abs().max() < 1e-4
assert (sample - off_sample).abs().max() < 1e-4
@slow
@deprecate_after_peft_backend