Unverified Commit 8eb17315 authored by Sayak Paul, committed by GitHub

[LoRA] get rid of the legacy lora remnants and make our codebase lighter (#8623)

* get rid of the legacy lora remnants and make our codebase lighter

* fix deprecated lora argument

* fix

* empty commit to trigger ci

* remove print

* empty
parent c71c19c5
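
Context for the change below: the tests no longer hand-build `LoRAAttnProcessor` objects and wrap them in `AttnProcsLayers` (the helpers removed in this commit); LoRA is instead described with peft's `LoraConfig` and injected through the model's PEFT adapter support. A minimal sketch of that flow, assuming the peft backend is installed; the toy constructor sizes are illustrative placeholders, not necessarily the values the test suite uses:

```python
from peft import LoraConfig

from diffusers.pipelines.wuerstchen import WuerstchenPrior

# Toy-sized prior for illustration only (assumed constructor values).
prior = WuerstchenPrior(c_in=2, c=8, depth=2, c_cond=32, c_r=8, nhead=2)

# LoRA is described declaratively with peft's LoraConfig ...
prior_lora_config = LoraConfig(
    r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
)

# ... and injected via the model's PEFT adapter mixin instead of building
# per-block LoRAAttnProcessor instances by hand.
prior.add_adapter(prior_lora_config)
```
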
@@ -75,7 +75,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
 
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
...
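
With the `return_deprecated_lora` argument gone, `get_processor()` returns each module's current attention processor directly, so the collected mapping can be inspected without any LoRA-specific handling. A hedged usage sketch (toy constructor sizes assumed):

```python
from diffusers.pipelines.wuerstchen import WuerstchenPrior

prior = WuerstchenPrior(c_in=2, c=8, depth=2, c_cond=32, c_r=8, nhead=2)  # assumed toy sizes

# attn_processors walks the module tree via fn_recursive_add_processors and
# returns a {qualified_name: processor} dict for every attention module.
for name, processor in prior.attn_processors.items():
    print(name, type(processor).__name__)
```
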
@@ -18,14 +18,10 @@ import unittest
 
 import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
 
 from diffusers import DDPMWuerstchenScheduler, StableCascadePriorPipeline
-from diffusers.loaders import AttnProcsLayers
 from diffusers.models import StableCascadeUNet
-from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
 from diffusers.utils.import_utils import is_peft_available
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
@@ -49,19 +45,6 @@ from ..test_pipelines_common import PipelineTesterMixin
 enable_full_determinism()
 
 
-def create_prior_lora_layers(unet: nn.Module):
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        lora_attn_processor_class = (
-            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-        )
-        lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=unet.config.c,
-        )
-    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
-    return lora_attn_procs, unet_lora_layers
-
-
 class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = StableCascadePriorPipeline
     params = ["prompt"]
@@ -240,19 +223,12 @@ class StableCascadePriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase
             r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
         )
 
-        prior_lora_attn_procs, prior_lora_layers = create_prior_lora_layers(prior)
-        lora_components = {
-            "prior_lora_layers": prior_lora_layers,
-            "prior_lora_attn_procs": prior_lora_attn_procs,
-        }
-        return prior, prior_lora_config, lora_components
+        return prior, prior_lora_config
 
     @require_peft_backend
     @unittest.skip(reason="no lora support for now")
     def test_inference_with_prior_lora(self):
-        _, prior_lora_config, _ = self.get_lora_components()
+        _, prior_lora_config = self.get_lora_components()
         device = "cpu"
 
         components = self.get_dummy_components()
...
@@ -17,16 +17,9 @@ import unittest
 
 import numpy as np
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import DDPMWuerstchenScheduler, WuerstchenPriorPipeline
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import (
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
-)
 from diffusers.pipelines.wuerstchen import WuerstchenPrior
 from diffusers.utils.import_utils import is_peft_available
 from diffusers.utils.testing_utils import enable_full_determinism, require_peft_backend, skip_mps, torch_device
@@ -42,19 +35,6 @@ from ..test_pipelines_common import PipelineTesterMixin
 enable_full_determinism()
 
 
-def create_prior_lora_layers(unet: nn.Module):
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        lora_attn_processor_class = (
-            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-        )
-        lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=unet.config.c,
-        )
-    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
-    return lora_attn_procs, unet_lora_layers
-
-
 class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = WuerstchenPriorPipeline
     params = ["prompt"]
@@ -262,18 +242,11 @@ class WuerstchenPriorPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
         )
 
-        prior_lora_attn_procs, prior_lora_layers = create_prior_lora_layers(prior)
-        lora_components = {
-            "prior_lora_layers": prior_lora_layers,
-            "prior_lora_attn_procs": prior_lora_attn_procs,
-        }
-        return prior, prior_lora_config, lora_components
+        return prior, prior_lora_config
 
     @require_peft_backend
     def test_inference_with_prior_lora(self):
-        _, prior_lora_config, _ = self.get_lora_components()
+        _, prior_lora_config = self.get_lora_components()
         device = "cpu"
 
         components = self.get_dummy_components()
...
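
For completeness, a sketch of one way to confirm that the `LoraConfig` returned by the simplified `get_lora_components()` actually injects adapters; the `lora_` parameter-name check leans on peft's `lora_A` / `lora_B` naming convention and is an assumption, not the helper the test suite itself uses:

```python
from peft import LoraConfig

from diffusers.pipelines.wuerstchen import WuerstchenPrior

prior = WuerstchenPrior(c_in=2, c=8, depth=2, c_cond=32, c_r=8, nhead=2)  # assumed toy sizes
prior.add_adapter(
    LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False)
)

# peft wraps the targeted linear layers, exposing lora_A / lora_B parameters.
lora_params = [name for name, _ in prior.named_parameters() if "lora_" in name]
assert lora_params, "LoRA layers were not injected into the prior"
```
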