Unverified Commit 6a376cee authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] remove unnecessary components from lora peft test suite (#6401)

remove unnecessary components from lora peft suite/
parent 9f283b01
...@@ -22,7 +22,6 @@ import unittest ...@@ -22,7 +22,6 @@ import unittest
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from huggingface_hub.repocard import RepoCard from huggingface_hub.repocard import RepoCard
from packaging import version from packaging import version
...@@ -41,8 +40,6 @@ from diffusers import ( ...@@ -41,8 +40,6 @@ from diffusers import (
StableDiffusionXLPipeline, StableDiffusionXLPipeline,
UNet2DConditionModel, UNet2DConditionModel,
) )
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from diffusers.utils.import_utils import is_accelerate_available, is_peft_available from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
floats_tensor, floats_tensor,
...@@ -78,28 +75,6 @@ def state_dicts_almost_equal(sd1, sd2): ...@@ -78,28 +75,6 @@ def state_dicts_almost_equal(sd1, sd2):
return models_are_equal return models_are_equal
def create_unet_lora_layers(unet: nn.Module):
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
return lora_attn_procs, unet_lora_layers
@require_peft_backend @require_peft_backend
class PeftLoraLoaderMixinTests: class PeftLoraLoaderMixinTests:
torch_device = "cuda" if torch.cuda.is_available() else "cpu" torch_device = "cuda" if torch.cuda.is_available() else "cpu"
...@@ -140,8 +115,6 @@ class PeftLoraLoaderMixinTests: ...@@ -140,8 +115,6 @@ class PeftLoraLoaderMixinTests:
r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
) )
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
if self.has_two_text_encoders: if self.has_two_text_encoders:
pipeline_components = { pipeline_components = {
"unet": unet, "unet": unet,
...@@ -165,11 +138,8 @@ class PeftLoraLoaderMixinTests: ...@@ -165,11 +138,8 @@ class PeftLoraLoaderMixinTests:
"feature_extractor": None, "feature_extractor": None,
"image_encoder": None, "image_encoder": None,
} }
lora_components = {
"unet_lora_layers": unet_lora_layers, return pipeline_components, text_lora_config, unet_lora_config
"unet_lora_attn_procs": unet_lora_attn_procs,
}
return pipeline_components, lora_components, text_lora_config, unet_lora_config
def get_dummy_inputs(self, with_generator=True): def get_dummy_inputs(self, with_generator=True):
batch_size = 1 batch_size = 1
...@@ -216,7 +186,7 @@ class PeftLoraLoaderMixinTests: ...@@ -216,7 +186,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple inference and makes sure it works as expected Tests a simple inference and makes sure it works as expected
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -231,7 +201,7 @@ class PeftLoraLoaderMixinTests: ...@@ -231,7 +201,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -262,7 +232,7 @@ class PeftLoraLoaderMixinTests: ...@@ -262,7 +232,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -309,7 +279,7 @@ class PeftLoraLoaderMixinTests: ...@@ -309,7 +279,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -351,7 +321,7 @@ class PeftLoraLoaderMixinTests: ...@@ -351,7 +321,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -394,7 +364,7 @@ class PeftLoraLoaderMixinTests: ...@@ -394,7 +364,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where users could use saving utilities for LoRA. Tests a simple usecase where users could use saving utilities for LoRA.
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -459,7 +429,7 @@ class PeftLoraLoaderMixinTests: ...@@ -459,7 +429,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -510,7 +480,7 @@ class PeftLoraLoaderMixinTests: ...@@ -510,7 +480,7 @@ class PeftLoraLoaderMixinTests:
Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -583,7 +553,7 @@ class PeftLoraLoaderMixinTests: ...@@ -583,7 +553,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -637,7 +607,7 @@ class PeftLoraLoaderMixinTests: ...@@ -637,7 +607,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected - with unet and makes sure it works as expected - with unet
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -683,7 +653,7 @@ class PeftLoraLoaderMixinTests: ...@@ -683,7 +653,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -730,7 +700,7 @@ class PeftLoraLoaderMixinTests: ...@@ -730,7 +700,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -780,7 +750,7 @@ class PeftLoraLoaderMixinTests: ...@@ -780,7 +750,7 @@ class PeftLoraLoaderMixinTests:
multiple adapters and set them multiple adapters and set them
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -848,7 +818,7 @@ class PeftLoraLoaderMixinTests: ...@@ -848,7 +818,7 @@ class PeftLoraLoaderMixinTests:
multiple adapters and set/delete them multiple adapters and set/delete them
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -938,7 +908,7 @@ class PeftLoraLoaderMixinTests: ...@@ -938,7 +908,7 @@ class PeftLoraLoaderMixinTests:
multiple adapters and set them multiple adapters and set them
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1010,7 +980,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1010,7 +980,7 @@ class PeftLoraLoaderMixinTests:
def test_lora_fuse_nan(self): def test_lora_fuse_nan(self):
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1048,7 +1018,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1048,7 +1018,7 @@ class PeftLoraLoaderMixinTests:
are the expected results are the expected results
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1075,7 +1045,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1075,7 +1045,7 @@ class PeftLoraLoaderMixinTests:
are the expected results are the expected results
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1113,7 +1083,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1113,7 +1083,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected - with unet and multi-adapter case and makes sure it works as expected - with unet and multi-adapter case
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
...@@ -1175,7 +1145,7 @@ class PeftLoraLoaderMixinTests: ...@@ -1175,7 +1145,7 @@ class PeftLoraLoaderMixinTests:
and makes sure it works as expected and makes sure it works as expected
""" """
for scheduler_cls in [DDIMScheduler, LCMScheduler]: for scheduler_cls in [DDIMScheduler, LCMScheduler]:
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
pipe = self.pipeline_class(**components) pipe = self.pipeline_class(**components)
pipe = pipe.to(self.torch_device) pipe = pipe.to(self.torch_device)
pipe.set_progress_bar_config(disable=None) pipe.set_progress_bar_config(disable=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment