"docs/vscode:/vscode.git/clone" did not exist on "1f36bd4cf24d221e61cf2609b7c6170e955222bf"
Unverified Commit 9d7c08f9, authored by Andy, committed by GitHub

[WIP] implement rest of the test cases (LoRA tests) (#2824)



* initial commit for lora test cases

* help a bit with lora for 3d

* fixed lora tests

* replaced redundant code

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent dc277501
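
The changes below thread `cross_attention_kwargs` through the 3D UNet's temporal attention blocks, add `UNet2DConditionLoadersMixin` to `UNet3DConditionModel`, and port the 2D LoRA tests to the 3D model. For orientation, here is a minimal, self-contained sketch of the workflow those tests exercise. The tiny model config and tensor shapes are illustrative (chosen to mirror the repo's test config), not part of this PR, and the processor-building loop follows the `create_lora_layers` helper shown further down:

    import tempfile

    import torch

    from diffusers import UNet3DConditionModel
    from diffusers.models.attention_processor import LoRAAttnProcessor

    # Tiny, randomly initialized 3D UNet; config values are illustrative only.
    unet = UNet3DConditionModel(
        sample_size=32,
        in_channels=4,
        out_channels=4,
        down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
        up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
        block_out_channels=(32, 64),
        layers_per_block=1,
        cross_attention_dim=32,
        attention_head_dim=8,
        norm_num_groups=32,
    )

    # Build one LoRA processor per attention module, mirroring the 3D-aware
    # create_lora_layers helper added in this PR.
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        has_cross_attention = name.endswith("attn2.processor") and not (
            name.startswith("transformer_in") or "temp_attentions" in name.split(".")
        )
        cross_attention_dim = unet.config.cross_attention_dim if has_cross_attention else None
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        elif name.startswith("transformer_in"):
            # the input temporal transformer projects to 8 * attention_head_dim channels
            hidden_size = 8 * unet.config.attention_head_dim
        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

    unet.set_attn_processor(lora_attn_procs)

    # save_attn_procs / load_attn_procs come from UNet2DConditionLoadersMixin,
    # which UNet3DConditionModel now inherits.
    with tempfile.TemporaryDirectory() as tmpdirname:
        unet.save_attn_procs(tmpdirname)   # writes pytorch_lora_weights.bin
        unet.load_attn_procs(tmpdirname)

    # cross_attention_kwargs={"scale": ...} now also reaches transformer_in and the
    # temporal attention blocks; scale=0.0 reproduces the base model's output.
    sample = torch.randn(1, 4, 2, 32, 32)           # (batch, channels, frames, height, width)
    timestep = torch.tensor([10])
    encoder_hidden_states = torch.randn(1, 4, 32)   # (batch, sequence, cross_attention_dim)

    with torch.no_grad():
        out = unet(sample, timestep, encoder_hidden_states, cross_attention_kwargs={"scale": 0.5}).sample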
@@ -251,7 +251,9 @@ class UNetMidBlock3DCrossAttn(nn.Module):
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
-            hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
+            hidden_states = temp_attn(
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+            ).sample
             hidden_states = resnet(hidden_states, temb)
             hidden_states = temp_conv(hidden_states, num_frames=num_frames)

@@ -376,7 +378,9 @@ class CrossAttnDownBlock3D(nn.Module):
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
-            hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
+            hidden_states = temp_attn(
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+            ).sample

             output_states += (hidden_states,)

@@ -587,7 +591,9 @@ class CrossAttnUpBlock3D(nn.Module):
                 encoder_hidden_states=encoder_hidden_states,
                 cross_attention_kwargs=cross_attention_kwargs,
             ).sample
-            hidden_states = temp_attn(hidden_states, num_frames=num_frames).sample
+            hidden_states = temp_attn(
+                hidden_states, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+            ).sample

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
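
Why the forwarding above matters: each attention module unpacks `cross_attention_kwargs` into its processor call, so a LoRA processor attached to a temporal attention block only receives the LoRA `scale` if the block passes the kwargs along. A minimal sketch of that mechanism (the standalone `Attention` layer below is illustrative, not a module from this diff):

    import torch

    from diffusers.models.attention_processor import Attention, LoRAAttnProcessor

    # One self-attention layer with a LoRA processor attached; the inner dim of 64
    # is an arbitrary choice for the sketch.
    attn = Attention(query_dim=64, heads=8, dim_head=8)
    attn.set_processor(LoRAAttnProcessor(hidden_size=64))

    hidden_states = torch.randn(2, 16, 64)  # (batch * frames, sequence, channels)

    with torch.no_grad():
        # extra keyword args given to the attention module are forwarded to the
        # processor, which is how {"scale": ...} from cross_attention_kwargs
        # reaches the LoRA layers
        out = attn(hidden_states, scale=0.5)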
@@ -20,6 +20,7 @@ import torch.nn as nn
 import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import UNet2DConditionLoadersMixin
 from ..utils import BaseOutput, logging
 from .attention_processor import AttentionProcessor, AttnProcessor
 from .embeddings import TimestepEmbedding, Timesteps

@@ -50,7 +51,7 @@ class UNet3DConditionOutput(BaseOutput):
     sample: torch.FloatTensor


-class UNet3DConditionModel(ModelMixin, ConfigMixin):
+class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     r"""
     UNet3DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
     and returns sample shaped output.

@@ -465,7 +466,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin):
         sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
         sample = self.conv_in(sample)

-        sample = self.transformer_in(sample, num_frames=num_frames).sample
+        sample = self.transformer_in(
+            sample, num_frames=num_frames, cross_attention_kwargs=cross_attention_kwargs
+        ).sample

         # 3. down
         down_block_res_samples = (sample,)
@@ -41,7 +41,7 @@ logger = logging.get_logger(__name__)
 torch.backends.cuda.matmul.allow_tf32 = False


-def create_lora_layers(model):
+def create_lora_layers(model, mock_weights: bool = True):
     lora_attn_procs = {}
     for name in model.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim

@@ -57,12 +57,13 @@ def create_lora_layers(model):
         lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

-        # add 1 to weights to mock trained weights
-        with torch.no_grad():
-            lora_attn_procs[name].to_q_lora.up.weight += 1
-            lora_attn_procs[name].to_k_lora.up.weight += 1
-            lora_attn_procs[name].to_v_lora.up.weight += 1
-            lora_attn_procs[name].to_out_lora.up.weight += 1
+        if mock_weights:
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1

     return lora_attn_procs

@@ -378,26 +379,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             sample1 = model(**inputs_dict).sample

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
+        lora_attn_procs = create_lora_layers(model)

         # make sure we can set a list of attention processors
         model.set_attn_processor(lora_attn_procs)

@@ -465,28 +447,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         with torch.no_grad():
             old_sample = model(**inputs_dict).sample

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
-
-            # add 1 to weights to mock trained weights
-            with torch.no_grad():
-                lora_attn_procs[name].to_q_lora.up.weight += 1
-                lora_attn_procs[name].to_k_lora.up.weight += 1
-                lora_attn_procs[name].to_v_lora.up.weight += 1
-                lora_attn_procs[name].to_out_lora.up.weight += 1
+        lora_attn_procs = create_lora_layers(model)

         model.set_attn_processor(lora_attn_procs)

         with torch.no_grad():

@@ -518,21 +479,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         model = self.model_class(**init_dict)
         model.to(torch_device)

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)

         model.set_attn_processor(lora_attn_procs)
         # Saving as torch, properly reloads with directly filename
         with tempfile.TemporaryDirectory() as tmpdirname:

@@ -553,21 +500,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         model = self.model_class(**init_dict)
         model.to(torch_device)

-        lora_attn_procs = {}
-        for name in model.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
-            if name.startswith("mid_block"):
-                hidden_size = model.config.block_out_channels[-1]
-            elif name.startswith("up_blocks"):
-                block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
-            elif name.startswith("down_blocks"):
-                block_id = int(name[len("down_blocks.")])
-                hidden_size = model.config.block_out_channels[block_id]
-
-            lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)

         model.set_attn_processor(lora_attn_procs)
         # Saving as torch, properly reloads with directly filename
         with tempfile.TemporaryDirectory() as tmpdirname:
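
A note on the `mock_weights` flag used above: the up-projections inside `LoRAAttnProcessor` are zero-initialized, so a freshly attached LoRA processor is a no-op until it is trained. That is why the functional tests bump the weights by +1 to mock training, while the pure save/load round-trip tests can skip that step. A quick check, as a sketch:

    import torch

    from diffusers.models.attention_processor import LoRAAttnProcessor

    proc = LoRAAttnProcessor(hidden_size=64, cross_attention_dim=32)

    # the "up" side of each LoRA pair starts at zero, so the LoRA branch initially
    # contributes nothing to the attention output
    assert torch.count_nonzero(proc.to_q_lora.up.weight) == 0
    assert torch.count_nonzero(proc.to_out_lora.up.weight) == 0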
@@ -13,13 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
+import tempfile
 import unittest

 import numpy as np
 import torch

 from diffusers.models import ModelMixin, UNet3DConditionModel
-from diffusers.models.attention_processor import LoRAAttnProcessor
+from diffusers.models.attention_processor import AttnProcessor, LoRAAttnProcessor
 from diffusers.utils import (
     floats_tensor,
     logging,

@@ -35,10 +37,13 @@ logger = logging.get_logger(__name__)
 torch.backends.cuda.matmul.allow_tf32 = False


-def create_lora_layers(model):
+def create_lora_layers(model, mock_weights: bool = True):
     lora_attn_procs = {}
     for name in model.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+        has_cross_attention = name.endswith("attn2.processor") and not (
+            name.startswith("transformer_in") or "temp_attentions" in name.split(".")
+        )
+        cross_attention_dim = model.config.cross_attention_dim if has_cross_attention else None
         if name.startswith("mid_block"):
             hidden_size = model.config.block_out_channels[-1]
         elif name.startswith("up_blocks"):

@@ -47,16 +52,20 @@ def create_lora_layers(model):
         elif name.startswith("down_blocks"):
             block_id = int(name[len("down_blocks.")])
             hidden_size = model.config.block_out_channels[block_id]
+        elif name.startswith("transformer_in"):
+            # Note that the `8 * ...` comes from: https://github.com/huggingface/diffusers/blob/7139f0e874f10b2463caa8cbd585762a309d12d6/src/diffusers/models/unet_3d_condition.py#L148
+            hidden_size = 8 * model.config.attention_head_dim

         lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
         lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

-        # add 1 to weights to mock trained weights
-        with torch.no_grad():
-            lora_attn_procs[name].to_q_lora.up.weight += 1
-            lora_attn_procs[name].to_k_lora.up.weight += 1
-            lora_attn_procs[name].to_v_lora.up.weight += 1
-            lora_attn_procs[name].to_out_lora.up.weight += 1
+        if mock_weights:
+            # add 1 to weights to mock trained weights
+            with torch.no_grad():
+                lora_attn_procs[name].to_q_lora.up.weight += 1
+                lora_attn_procs[name].to_k_lora.up.weight += 1
+                lora_attn_procs[name].to_v_lora.up.weight += 1
+                lora_attn_procs[name].to_out_lora.up.weight += 1

     return lora_attn_procs

@@ -190,23 +199,173 @@ class UNet3DConditionModelTests(ModelTesterMixin, unittest.TestCase):
         output = model(**inputs_dict)

         assert output is not None

-    # (`attn_processors`) needs to be implemented in this model for this test.
-    # def test_lora_processors(self):
-
-    # (`attn_processors`) needs to be implemented in this model for this test.
-    # def test_lora_save_load(self):
-
-    # (`attn_processors`) needs to be implemented for this test in the model.
-    # def test_lora_save_load_safetensors(self):
-
-    # (`attn_processors`) needs to be implemented for this test in the model.
-    # def test_lora_save_safetensors_load_torch(self):
-
-    # (`attn_processors`) needs to be implemented for this test.
-    # def test_lora_save_torch_force_load_safetensors_error(self):
-
-    # (`attn_processors`) needs to be added for this test.
-    # def test_lora_on_off(self):
+    def test_lora_processors(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            sample1 = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+
+        # make sure we can set a list of attention processors
+        model.set_attn_processor(lora_attn_procs)
+        model.to(torch_device)
+
+        # test that attn processors can be set to itself
+        model.set_attn_processor(model.attn_processors)
+
+        with torch.no_grad():
+            sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
+            sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+            sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample1 - sample2).abs().max() < 1e-4
+        assert (sample3 - sample4).abs().max() < 1e-4
+
+        # sample 2 and sample 3 should be different
+        assert (sample2 - sample3).abs().max() > 1e-4
+
+    def test_lora_save_load(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname)
+
+        with torch.no_grad():
+            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+
+        # LoRA and no LoRA should NOT be the same
+        assert (sample - old_sample).abs().max() > 1e-4
+
+    def test_lora_save_load_safetensors(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname, safe_serialization=True)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname)
+
+        with torch.no_grad():
+            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+
+        # LoRA and no LoRA should NOT be the same
+        assert (sample - old_sample).abs().max() > 1e-4
+
+    def test_lora_save_safetensors_load_torch(self):
+        # enable deterministic behavior for gradient checkpointing
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
+        model.set_attn_processor(lora_attn_procs)
+        # Saving as torch, properly reloads with directly filename
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
+
+    def test_lora_save_torch_force_load_safetensors_error(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        lora_attn_procs = create_lora_layers(model, mock_weights=False)
+        model.set_attn_processor(lora_attn_procs)
+        # Saving as torch, properly reloads with directly filename
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            model.save_attn_procs(tmpdirname)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            torch.manual_seed(0)
+            new_model = self.model_class(**init_dict)
+            new_model.to(torch_device)
+            with self.assertRaises(IOError) as e:
+                new_model.load_attn_procs(tmpdirname, use_safetensors=True)
+            self.assertIn("Error no file named pytorch_lora_weights.safetensors", str(e.exception))
+
+    def test_lora_on_off(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = 8
+
+        torch.manual_seed(0)
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        with torch.no_grad():
+            old_sample = model(**inputs_dict).sample
+
+        lora_attn_procs = create_lora_layers(model)
+        model.set_attn_processor(lora_attn_procs)
+
+        with torch.no_grad():
+            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
+
+        model.set_attn_processor(AttnProcessor())
+
+        with torch.no_grad():
+            new_sample = model(**inputs_dict).sample
+
+        assert (sample - new_sample).abs().max() < 1e-4
+        assert (sample - old_sample).abs().max() < 1e-4

     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),