Unverified Commit 8eb17315 authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] get rid of the legacy lora remnants and make our codebase lighter (#8623)

* get rid of the legacy lora remnants and make our codebase lighter

* fix depcrecated lora argument

* fix

* empty commit to trigger ci

* remove print

* empty
parent c71c19c5
......@@ -75,7 +75,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
......
......@@ -18,14 +18,10 @@ import unittest
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline
from diffusers.loaders import AttnProcsLayers
from diffusers.models import StableCascadeUNet
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
enable_full_determinism,
......@@ -49,19 +45,6 @@ from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
def create_prior_lora_layers(unet: nn.Module):
lora_attn_procs = {}
for name in unet.attn_processors.keys():
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=unet.config.c,
)
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
return lora_attn_procs, unet_lora_layers
class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableCascadePriorPipeline
params = ["prompt"]
......@@ -240,19 +223,12 @@ class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)
prior_lora_attn_procs, prior_lora_layers = create_prior_lora_layers(prior)
lora_components = {
"prior_lora_layers": prior_lora_layers,
"prior_lora_attn_procs": prior_lora_attn_procs,
}
return prior, prior_lora_config, lora_components
return prior, prior_lora_config
@require_peft_backend
@unittest.skip(reason="no lora support for now")
def test_inference_with_prior_lora(self):
_, prior_lora_config, _ = self.get_lora_components()
_, prior_lora_config = self.get_lora_components()
device = "cpu"
components = self.get_dummy_components()
......
......@@ -17,16 +17,9 @@ import unittest
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import (
LoRAAttnProcessor,
LoRAAttnProcessor2_0,
)
from diffusers.pipelines.wuerstchen import WuerstchenPrior
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import enable_full_determinism, require_peft_backend, skip_mps, torch_device
......@@ -42,19 +35,6 @@ from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
def create_prior_lora_layers(unet: nn.Module):
lora_attn_procs = {}
for name in unet.attn_processors.keys():
lora_attn_processor_class = (
LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
)
lora_attn_procs[name] = lora_attn_processor_class(
hidden_size=unet.config.c,
)
unet_lora_layers = AttnProcsLayers(lora_attn_procs)
return lora_attn_procs, unet_lora_layers
class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = WuerstchenPriorPipeline
params = ["prompt"]
......@@ -262,18 +242,11 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)
prior_lora_attn_procs, prior_lora_layers = create_prior_lora_layers(prior)
lora_components = {
"prior_lora_layers": prior_lora_layers,
"prior_lora_attn_procs": prior_lora_attn_procs,
}
return prior, prior_lora_config, lora_components
return prior, prior_lora_config
@require_peft_backend
def test_inference_with_prior_lora(self):
_, prior_lora_config, _ = self.get_lora_components()
_, prior_lora_config = self.get_lora_components()
device = "cpu"
components = self.get_dummy_components()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment