"torchvision/csrc/ops/cpu/nms_kernel.cpp" did not exist on "0ebbb0abd0610c8ffe978902c06751f94a2e3197"
Unverified Commit 6a376cee authored by Sayak Paul, committed by GitHub

[LoRA] remove unnecessary components from lora peft test suite (#6401)

remove unnecessary components from lora peft test suite
parent 9f283b01
@@ -22,7 +22,6 @@ import unittest
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
 from huggingface_hub.repocard import RepoCard
 from packaging import version
@@ -41,8 +40,6 @@ from diffusers import (
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
 from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
 from diffusers.utils.testing_utils import (
     floats_tensor,
@@ -78,28 +75,6 @@ def state_dicts_almost_equal(sd1, sd2):
     return models_are_equal
-
-def create_unet_lora_layers(unet: nn.Module):
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-        lora_attn_processor_class = (
-            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-        )
-        lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
-    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
-    return lora_attn_procs, unet_lora_layers
-
 @require_peft_backend
 class PeftLoraLoaderMixinTests:
     torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -140,8 +115,6 @@ class PeftLoraLoaderMixinTests:
             r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
         )
-        unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
-
         if self.has_two_text_encoders:
             pipeline_components = {
                 "unet": unet,
@@ -165,11 +138,8 @@ class PeftLoraLoaderMixinTests:
                 "feature_extractor": None,
                 "image_encoder": None,
             }
-        lora_components = {
-            "unet_lora_layers": unet_lora_layers,
-            "unet_lora_attn_procs": unet_lora_attn_procs,
-        }
-
-        return pipeline_components, lora_components, text_lora_config, unet_lora_config
+        return pipeline_components, text_lora_config, unet_lora_config

     def get_dummy_inputs(self, with_generator=True):
         batch_size = 1
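
Note: every call site below changes accordingly, dropping the now-unused lora_components slot from the unpacking:

# Before: the 4-tuple forced callers to discard lora_components.
components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
# After: the helper returns only what the PEFT-backed tests consume.
components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)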
@@ -216,7 +186,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -231,7 +201,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -262,7 +232,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -309,7 +279,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -351,7 +321,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -394,7 +364,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple usecase where users could use saving utilities for LoRA.
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -459,7 +429,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -510,7 +480,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
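
Note: the save/load round trip this test covers looks roughly like the following sketch, which uses the standard LoraLoaderMixin API; the two state-dict variables are illustrative names, not ones from the test:

import tempfile

with tempfile.TemporaryDirectory() as tmpdirname:
    # Persist adapter weights extracted from the UNet and text encoder.
    self.pipeline_class.save_lora_weights(
        save_directory=tmpdirname,
        unet_lora_layers=unet_state_dict,  # illustrative
        text_encoder_lora_layers=text_encoder_state_dict,  # illustrative
    )
    # Drop the live adapters, then reload them from disk.
    pipe.unload_lora_weights()
    pipe.load_lora_weights(tmpdirname)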
@@ -583,7 +553,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -637,7 +607,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected - with unet
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -683,7 +653,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -730,7 +700,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -780,7 +750,7 @@ class PeftLoraLoaderMixinTests:
         multiple adapters and set them
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -848,7 +818,7 @@ class PeftLoraLoaderMixinTests:
         multiple adapters and set/delete them
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
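
Note: the multi-adapter controls these tests exercise are public pipeline methods in the PEFT integration. A sketch, assuming two adapters created from the same config:

pipe.unet.add_adapter(unet_lora_config, "adapter-1")
pipe.unet.add_adapter(unet_lora_config, "adapter-2")

pipe.set_adapters(["adapter-1", "adapter-2"])  # activate both adapters
pipe.set_adapters("adapter-1")                 # switch to a single adapter
pipe.delete_adapters("adapter-2")              # remove an adapter entirely
pipe.disable_lora()                            # fall back to the base weights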
@@ -938,7 +908,7 @@ class PeftLoraLoaderMixinTests:
         multiple adapters and set them
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1010,7 +980,7 @@ class PeftLoraLoaderMixinTests:
     def test_lora_fuse_nan(self):
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
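
Note: test_lora_fuse_nan corrupts an adapter weight and checks that fusing is guarded. A hedged sketch of the pattern, assuming fuse_lora's safe_fusing flag; `some_lora_weight` is an illustrative handle, not the weight the real test touches:

pipe.unet.add_adapter(unet_lora_config, "adapter-1")

with torch.no_grad():
    # Force a non-finite value into one (illustrative) LoRA weight.
    some_lora_weight += float("inf")

# With safe_fusing=True the non-finite check should raise.
with self.assertRaises(ValueError):
    pipe.fuse_lora(safe_fusing=True)

# Without the check, fusing proceeds (outputs are expected to be garbage).
pipe.fuse_lora(safe_fusing=False)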
@@ -1048,7 +1018,7 @@ class PeftLoraLoaderMixinTests:
         are the expected results
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1075,7 +1045,7 @@ class PeftLoraLoaderMixinTests:
         are the expected results
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1113,7 +1083,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected - with unet and multi-adapter case
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1175,7 +1145,7 @@ class PeftLoraLoaderMixinTests:
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)