Unverified Commit 8eb17315 authored by Sayak Paul, committed by GitHub

[LoRA] get rid of the legacy lora remnants and make our codebase lighter (#8623)

* get rid of the legacy lora remnants and make our codebase lighter

* fix deprecated lora argument

* fix

* empty commit to trigger ci

* remove print

* empty
parent c71c19c5
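The deprecation notices removed in this commit point users at the PEFT backend instead of the dedicated LoRA attention processors. Below is a minimal sketch of that replacement path, assuming `peft` is installed (`pip install peft`); the base checkpoint is a real Hub repo, while the LoRA repo id is a hypothetical placeholder.

```python
# Minimal sketch: with the PEFT backend, LoRA weights are injected into the
# attention Linear layers by `load_lora_weights`, so none of the LoRA-specific
# attention processor classes deleted in this commit are needed.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-sdxl-lora")  # hypothetical LoRA repo id
image = pipe("an astronaut riding a horse", num_inference_steps=30).images[0]
```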
@@ -41,12 +41,6 @@ An attention processor is a class for applying different types of attention mech
 ## FusedAttnProcessor2_0
 [[autodoc]] models.attention_processor.FusedAttnProcessor2_0
-## LoRAAttnAddedKVProcessor
-[[autodoc]] models.attention_processor.LoRAAttnAddedKVProcessor
-## LoRAXFormersAttnProcessor
-[[autodoc]] models.attention_processor.LoRAXFormersAttnProcessor
 ## SlicedAttnProcessor
 [[autodoc]] models.attention_processor.SlicedAttnProcessor
...
@@ -24,12 +24,7 @@ from diffusers import DiffusionPipeline, StableDiffusionXLPipeline
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
+from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
 from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
 from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -1292,12 +1287,7 @@ class SDXLLongPromptWeightingPipeline(
         self.vae.to(dtype=torch.float32)
         use_torch_2_0_or_xformers = isinstance(
             self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
+            (AttnProcessor2_0, XFormersAttnProcessor),
         )
         # if xformers or torch_2_0 is used attention block does not need
         # to be in float32 which can save lots of memory
...
@@ -16,12 +16,7 @@ from diffusers.loaders import (
     TextualInversionLoaderMixin,
 )
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
+from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
 from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -612,12 +607,7 @@ class DemoFusionSDXLPipeline(
         self.vae.to(dtype=torch.float32)
         use_torch_2_0_or_xformers = isinstance(
             self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
+            (AttnProcessor2_0, XFormersAttnProcessor),
         )
         # if xformers or torch_2_0 is used attention block does not need
         # to be in float32 which can save lots of memory
...
@@ -46,8 +46,6 @@ from diffusers.models.attention_processor import (
     Attention,
     AttnProcessor2_0,
     FusedAttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
 from diffusers.models.lora import adjust_lora_scale_text_encoder
@@ -1153,8 +1151,6 @@ class StyleAlignedSDXLPipeline(
             (
                 AttnProcessor2_0,
                 XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
                 FusedAttnProcessor2_0,
             ),
         )
...
@@ -25,12 +25,7 @@ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokeniz
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL, ControlNetModel, MultiAdapter, T2IAdapter, UNet2DConditionModel
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
+from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
@@ -797,12 +792,7 @@ class StableDiffusionXLControlNetAdapterPipeline(
         self.vae.to(dtype=torch.float32)
         use_torch_2_0_or_xformers = isinstance(
             self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
+            (AttnProcessor2_0, XFormersAttnProcessor),
         )
         # if xformers or torch_2_0 is used attention block does not need
         # to be in float32 which can save lots of memory
...
@@ -44,12 +44,7 @@ from diffusers.models import (
     T2IAdapter,
     UNet2DConditionModel,
 )
-from diffusers.models.attention_processor import (
-    AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    XFormersAttnProcessor,
-)
+from diffusers.models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
 from diffusers.pipelines.pipeline_utils import StableDiffusionMixin
@@ -1135,12 +1130,7 @@ class StableDiffusionXLControlNetAdapterInpaintPipeline(
         self.vae.to(dtype=torch.float32)
         use_torch_2_0_or_xformers = isinstance(
             self.vae.decoder.mid_block.attentions[0].processor,
-            (
-                AttnProcessor2_0,
-                XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
-            ),
+            (AttnProcessor2_0, XFormersAttnProcessor),
         )
         # if xformers or torch_2_0 is used attention block does not need
         # to be in float32 which can save lots of memory
...
@@ -37,8 +37,6 @@ from diffusers.loaders import (
 from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
 from diffusers.models.lora import adjust_lora_scale_text_encoder
@@ -854,8 +852,6 @@ class StableDiffusionXLDifferentialImg2ImgPipeline(
             (
                 AttnProcessor2_0,
                 XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
             ),
         )
         # if xformers or torch_2_0 is used attention block does not need
...
@@ -34,8 +34,6 @@ from diffusers.loaders import (
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
 from diffusers.models.lora import adjust_lora_scale_text_encoder
@@ -662,8 +660,6 @@ class StableDiffusionXLPipelineIpex(
             (
                 AttnProcessor2_0,
                 XFormersAttnProcessor,
-                LoRAXFormersAttnProcessor,
-                LoRAAttnProcessor2_0,
             ),
         )
         # if xformers or torch_2_0 is used attention block does not need
...
@@ -13,7 +13,6 @@
 # limitations under the License.
 import inspect
 import math
-from importlib import import_module
 from typing import Callable, List, Optional, Union
 import torch
@@ -24,7 +23,6 @@ from ..image_processor import IPAdapterMaskProcessor
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_torch_npu_available, is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
-from .lora import LoRALinearLayer
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -259,10 +257,6 @@ class Attention(nn.Module):
                 The attention operation to use. Defaults to `None` which uses the default attention operation from
                 `xformers`.
         """
-        is_lora = hasattr(self, "processor") and isinstance(
-            self.processor,
-            LORA_ATTENTION_PROCESSORS,
-        )
         is_custom_diffusion = hasattr(self, "processor") and isinstance(
             self.processor,
             (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor, CustomDiffusionAttnProcessor2_0),
@@ -274,14 +268,13 @@ class Attention(nn.Module):
                 AttnAddedKVProcessor2_0,
                 SlicedAttnAddedKVProcessor,
                 XFormersAttnAddedKVProcessor,
-                LoRAAttnAddedKVProcessor,
             ),
         )
         if use_memory_efficient_attention_xformers:
-            if is_added_kv_processor and (is_lora or is_custom_diffusion):
+            if is_added_kv_processor and is_custom_diffusion:
                 raise NotImplementedError(
-                    f"Memory efficient attention is currently not supported for LoRA or custom diffusion for attention processor type {self.processor}"
+                    f"Memory efficient attention is currently not supported for custom diffusion for attention processor type {self.processor}"
                 )
             if not is_xformers_available():
                 raise ModuleNotFoundError(
@@ -307,18 +300,7 @@ class Attention(nn.Module):
                 except Exception as e:
                     raise e
-            if is_lora:
-                # TODO (sayakpaul): should we throw a warning if someone wants to use the xformers
-                # variant when using PT 2.0 now that we have LoRAAttnProcessor2_0?
-                processor = LoRAXFormersAttnProcessor(
-                    hidden_size=self.processor.hidden_size,
-                    cross_attention_dim=self.processor.cross_attention_dim,
-                    rank=self.processor.rank,
-                    attention_op=attention_op,
-                )
-                processor.load_state_dict(self.processor.state_dict())
-                processor.to(self.processor.to_q_lora.up.weight.device)
-            elif is_custom_diffusion:
+            if is_custom_diffusion:
                 processor = CustomDiffusionXFormersAttnProcessor(
                     train_kv=self.processor.train_kv,
                     train_q_out=self.processor.train_q_out,
@@ -341,18 +323,7 @@ class Attention(nn.Module):
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
-            if is_lora:
-                attn_processor_class = (
-                    LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-                )
-                processor = attn_processor_class(
-                    hidden_size=self.processor.hidden_size,
-                    cross_attention_dim=self.processor.cross_attention_dim,
-                    rank=self.processor.rank,
-                )
-                processor.load_state_dict(self.processor.state_dict())
-                processor.to(self.processor.to_q_lora.up.weight.device)
-            elif is_custom_diffusion:
+            if is_custom_diffusion:
                 attn_processor_class = (
                     CustomDiffusionAttnProcessor2_0
                     if hasattr(F, "scaled_dot_product_attention")
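With the legacy processors gone, the hunks above leave only the custom-diffusion special case in `set_use_memory_efficient_attention_xformers`. A hedged, self-contained sketch of the user-facing effect, assuming `xformers` is installed; the model id is an illustrative placeholder:

```python
# Sketch: toggling xformers attention no longer involves a LoRA-specific branch,
# because PEFT-backed LoRA lives on the Linear modules rather than in a processor.
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
)
unet.enable_xformers_memory_efficient_attention()  # requires xformers to be installed
```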
@@ -442,82 +413,6 @@ class Attention(nn.Module):
         if not return_deprecated_lora:
             return self.processor
-        # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
-        # serialization format for LoRA Attention Processors. It should be deleted once the integration
-        # with PEFT is completed.
-        is_lora_activated = {
-            name: module.lora_layer is not None
-            for name, module in self.named_modules()
-            if hasattr(module, "lora_layer")
-        }
-        # 1. if no layer has a LoRA activated we can return the processor as usual
-        if not any(is_lora_activated.values()):
-            return self.processor
-        # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
-        is_lora_activated.pop("add_k_proj", None)
-        is_lora_activated.pop("add_v_proj", None)
-        # 2. else it is not possible that only some layers have LoRA activated
-        if not all(is_lora_activated.values()):
-            raise ValueError(
-                f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
-            )
-        # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
-        non_lora_processor_cls_name = self.processor.__class__.__name__
-        lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
-        hidden_size = self.inner_dim
-        # now create a LoRA attention processor from the LoRA layers
-        if lora_processor_cls in [LoRAAttnProcessor, LoRAAttnProcessor2_0, LoRAXFormersAttnProcessor]:
-            kwargs = {
-                "cross_attention_dim": self.cross_attention_dim,
-                "rank": self.to_q.lora_layer.rank,
-                "network_alpha": self.to_q.lora_layer.network_alpha,
-                "q_rank": self.to_q.lora_layer.rank,
-                "q_hidden_size": self.to_q.lora_layer.out_features,
-                "k_rank": self.to_k.lora_layer.rank,
-                "k_hidden_size": self.to_k.lora_layer.out_features,
-                "v_rank": self.to_v.lora_layer.rank,
-                "v_hidden_size": self.to_v.lora_layer.out_features,
-                "out_rank": self.to_out[0].lora_layer.rank,
-                "out_hidden_size": self.to_out[0].lora_layer.out_features,
-            }
-            if hasattr(self.processor, "attention_op"):
-                kwargs["attention_op"] = self.processor.attention_op
-            lora_processor = lora_processor_cls(hidden_size, **kwargs)
-            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
-            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
-            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
-            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
-        elif lora_processor_cls == LoRAAttnAddedKVProcessor:
-            lora_processor = lora_processor_cls(
-                hidden_size,
-                cross_attention_dim=self.add_k_proj.weight.shape[0],
-                rank=self.to_q.lora_layer.rank,
-                network_alpha=self.to_q.lora_layer.network_alpha,
-            )
-            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
-            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
-            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
-            lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
-            # only save if used
-            if self.add_k_proj.lora_layer is not None:
-                lora_processor.add_k_proj_lora.load_state_dict(self.add_k_proj.lora_layer.state_dict())
-                lora_processor.add_v_proj_lora.load_state_dict(self.add_v_proj.lora_layer.state_dict())
-            else:
-                lora_processor.add_k_proj_lora = None
-                lora_processor.add_v_proj_lora = None
-        else:
-            raise ValueError(f"{lora_processor_cls} does not exist.")
-        return lora_processor
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -2238,264 +2133,6 @@ class SpatialNorm(nn.Module):
         return new_f
-class LoRAAttnProcessor(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        cross_attention_dim: Optional[int] = None,
-        rank: int = 4,
-        network_alpha: Optional[int] = None,
-        **kwargs,
-    ):
-        deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
-        deprecate("LoRAAttnProcessor", "0.30.0", deprecation_message, standard_warn=False)
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.rank = rank
-        q_rank = kwargs.pop("q_rank", None)
-        q_hidden_size = kwargs.pop("q_hidden_size", None)
-        q_rank = q_rank if q_rank is not None else rank
-        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
-        v_rank = kwargs.pop("v_rank", None)
-        v_hidden_size = kwargs.pop("v_hidden_size", None)
-        v_rank = v_rank if v_rank is not None else rank
-        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
-        out_rank = kwargs.pop("out_rank", None)
-        out_hidden_size = kwargs.pop("out_hidden_size", None)
-        out_rank = out_rank if out_rank is not None else rank
-        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
-        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
-    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
-        self_cls_name = self.__class__.__name__
-        deprecate(
-            self_cls_name,
-            "0.26.0",
-            (
-                f"Make sure use {self_cls_name[4:]} instead by setting"
-                "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
-                " `LoraLoaderMixin.load_lora_weights`"
-            ),
-        )
-        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
-        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
-        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
-        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
-        attn._modules.pop("processor")
-        attn.processor = AttnProcessor()
-        return attn.processor(attn, hidden_states, **kwargs)
-class LoRAAttnProcessor2_0(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        cross_attention_dim: Optional[int] = None,
-        rank: int = 4,
-        network_alpha: Optional[int] = None,
-        **kwargs,
-    ):
-        deprecation_message = "Using LoRAAttnProcessor is deprecated. Please use the PEFT backend for all things LoRA. You can install PEFT by running `pip install peft`."
-        deprecate("LoRAAttnProcessor2_0", "0.30.0", deprecation_message, standard_warn=False)
-        super().__init__()
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.rank = rank
-        q_rank = kwargs.pop("q_rank", None)
-        q_hidden_size = kwargs.pop("q_hidden_size", None)
-        q_rank = q_rank if q_rank is not None else rank
-        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
-        v_rank = kwargs.pop("v_rank", None)
-        v_hidden_size = kwargs.pop("v_hidden_size", None)
-        v_rank = v_rank if v_rank is not None else rank
-        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
-        out_rank = kwargs.pop("out_rank", None)
-        out_hidden_size = kwargs.pop("out_hidden_size", None)
-        out_rank = out_rank if out_rank is not None else rank
-        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
-        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
-    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
-        self_cls_name = self.__class__.__name__
-        deprecate(
-            self_cls_name,
-            "0.26.0",
-            (
-                f"Make sure use {self_cls_name[4:]} instead by setting"
-                "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
-                " `LoraLoaderMixin.load_lora_weights`"
-            ),
-        )
-        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
-        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
-        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
-        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
-        attn._modules.pop("processor")
-        attn.processor = AttnProcessor2_0()
-        return attn.processor(attn, hidden_states, **kwargs)
-class LoRAXFormersAttnProcessor(nn.Module):
-    r"""
-    Processor for implementing the LoRA attention mechanism with memory efficient attention using xFormers.
-    Args:
-        hidden_size (`int`, *optional*):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`, *optional*):
-            The number of channels in the `encoder_hidden_states`.
-        rank (`int`, defaults to 4):
-            The dimension of the LoRA update matrices.
-        attention_op (`Callable`, *optional*, defaults to `None`):
-            The base
-            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
-            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
-            operator.
-        network_alpha (`int`, *optional*):
-            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
-        kwargs (`dict`):
-            Additional keyword arguments to pass to the `LoRALinearLayer` layers.
-    """
-    def __init__(
-        self,
-        hidden_size: int,
-        cross_attention_dim: int,
-        rank: int = 4,
-        attention_op: Optional[Callable] = None,
-        network_alpha: Optional[int] = None,
-        **kwargs,
-    ):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.rank = rank
-        self.attention_op = attention_op
-        q_rank = kwargs.pop("q_rank", None)
-        q_hidden_size = kwargs.pop("q_hidden_size", None)
-        q_rank = q_rank if q_rank is not None else rank
-        q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
-        v_rank = kwargs.pop("v_rank", None)
-        v_hidden_size = kwargs.pop("v_hidden_size", None)
-        v_rank = v_rank if v_rank is not None else rank
-        v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
-        out_rank = kwargs.pop("out_rank", None)
-        out_hidden_size = kwargs.pop("out_hidden_size", None)
-        out_rank = out_rank if out_rank is not None else rank
-        out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
-        self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
-    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
-        self_cls_name = self.__class__.__name__
-        deprecate(
-            self_cls_name,
-            "0.26.0",
-            (
-                f"Make sure use {self_cls_name[4:]} instead by setting"
-                "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
-                " `LoraLoaderMixin.load_lora_weights`"
-            ),
-        )
-        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
-        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
-        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
-        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
-        attn._modules.pop("processor")
-        attn.processor = XFormersAttnProcessor()
-        return attn.processor(attn, hidden_states, **kwargs)
-class LoRAAttnAddedKVProcessor(nn.Module):
-    r"""
-    Processor for implementing the LoRA attention mechanism with extra learnable key and value matrices for the text
-    encoder.
-    Args:
-        hidden_size (`int`, *optional*):
-            The hidden size of the attention layer.
-        cross_attention_dim (`int`, *optional*, defaults to `None`):
-            The number of channels in the `encoder_hidden_states`.
-        rank (`int`, defaults to 4):
-            The dimension of the LoRA update matrices.
-        network_alpha (`int`, *optional*):
-            Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
-        kwargs (`dict`):
-            Additional keyword arguments to pass to the `LoRALinearLayer` layers.
-    """
-    def __init__(
-        self,
-        hidden_size: int,
-        cross_attention_dim: Optional[int] = None,
-        rank: int = 4,
-        network_alpha: Optional[int] = None,
-    ):
-        super().__init__()
-        self.hidden_size = hidden_size
-        self.cross_attention_dim = cross_attention_dim
-        self.rank = rank
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
-        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
-    def __call__(self, attn: Attention, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
-        self_cls_name = self.__class__.__name__
-        deprecate(
-            self_cls_name,
-            "0.26.0",
-            (
-                f"Make sure use {self_cls_name[4:]} instead by setting"
-                "LoRA layers to `self.{to_q,to_k,to_v,add_k_proj,add_v_proj,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
-                " `LoraLoaderMixin.load_lora_weights`"
-            ),
-        )
-        attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
-        attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
-        attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
-        attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
-        attn._modules.pop("processor")
-        attn.processor = AttnAddedKVProcessor()
-        return attn.processor(attn, hidden_states, **kwargs)
 class IPAdapterAttnProcessor(nn.Module):
     r"""
     Attention processor for Multiple IP-Adapters.
@@ -2924,19 +2561,11 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
         return hidden_states
-LORA_ATTENTION_PROCESSORS = (
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    LoRAAttnAddedKVProcessor,
-)
 ADDED_KV_ATTENTION_PROCESSORS = (
     AttnAddedKVProcessor,
     SlicedAttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
     XFormersAttnAddedKVProcessor,
-    LoRAAttnAddedKVProcessor,
 )
 CROSS_ATTENTION_PROCESSORS = (
@@ -2944,9 +2573,6 @@ CROSS_ATTENTION_PROCESSORS = (
     AttnProcessor2_0,
     XFormersAttnProcessor,
     SlicedAttnProcessor,
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
 )
@@ -2964,9 +2590,4 @@ AttentionProcessor = Union[
     CustomDiffusionAttnProcessor,
     CustomDiffusionXFormersAttnProcessor,
     CustomDiffusionAttnProcessor2_0,
-    # deprecated
-    LoRAAttnProcessor,
-    LoRAAttnProcessor2_0,
-    LoRAXFormersAttnProcessor,
-    LoRAAttnAddedKVProcessor,
 ]
...
@@ -175,7 +175,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
...
@@ -253,7 +253,7 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
...
@@ -211,7 +211,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
...
@@ -530,7 +530,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
...
@@ -149,7 +149,7 @@ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginal
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
...
@@ -880,7 +880,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
...
@@ -373,7 +373,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
...
@@ -179,7 +179,7 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
...
@@ -152,7 +152,7 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
...
@@ -705,7 +705,7 @@ class UNet2DConditionModel(
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
...
@@ -301,7 +301,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor()
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
...
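All of the model hunks above make the same one-line change: the recursive `attn_processors` traversal now calls `get_processor()` without the removed `return_deprecated_lora` flag. Here is a self-contained sketch of that traversal pattern, adapted from the code shown in these hunks; the helper name `collect_attn_processors` is ours, not part of the diff.

```python
# Sketch of the recursive processor collection these hunks simplify; it works on any
# diffusers model whose attention modules expose `get_processor`.
from typing import Dict

import torch


def collect_attn_processors(model: torch.nn.Module) -> Dict[str, object]:
    processors: Dict[str, object] = {}

    def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, object]):
        if hasattr(module, "get_processor"):
            # no more `return_deprecated_lora=True`: the plain processor is returned
            processors[f"{name}.processor"] = module.get_processor()
        for sub_name, child in module.named_children():
            fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
        return processors

    for name, module in model.named_children():
        fn_recursive_add_processors(name, module, processors)
    return processors
```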