Unverified Commit 531e7191 authored by Sayak Paul, committed by GitHub

[LoRA] use the PyTorch classes wherever needed and start deprecation cycles (#7204)

* fix PyTorch classes and start deprecation cycles.

* remove args crafting for accommodating scale.

* remove scale check in feedforward.

* assert against nn.Linear and not CompatibleLinear.

* remove conv_cls and linear_cls.

* remove scale

* 👋 scale.

* fix: unet2dcondition

* fix attention.py

* fix: attention.py again

* fix: unet_2d_blocks.

* fix-copies.

* more fixes.

* fix: resnet.py

* more fixes

* fix i2vgenxl unet.

* deprecate scale gently.

* fix-copies

* Apply suggestions from code review
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* quality

* throw warning when scale is passed to the BasicTransformerBlock class.

* remove scale from signature.

* cross_attention_kwargs, very nice catch by Yiyi

* fix: logger.warn

* make deprecation message clearer.

* address final comments.

* maintain same deprecation message and also add it to activations.

* address yiyi

* fix copies

* Apply suggestions from code review
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* more deprecation

* fix-copies

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 4fbd310f
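The diff below applies one pattern throughout: the LoRA wrapper layers (`LoRACompatibleLinear`/`LoRACompatibleConv`) are replaced with plain `nn.Linear`/`nn.Conv2d`, and the per-call `scale` argument is no longer consumed by these modules. Instead, each affected `forward` accepts `*args, **kwargs` only to detect a stray `scale` and emit a deprecation warning. A minimal sketch of that pattern, using a hypothetical `MyBlock` module (not part of diffusers) to stand in for GEGLU, FeedForward, or an attention processor:

import torch
from torch import nn

from diffusers.utils import deprecate  # same helper used in the diff below


class MyBlock(nn.Module):  # hypothetical stand-in, for illustration only
    def __init__(self, dim_in: int, dim_out: int):
        super().__init__()
        # plain nn.Linear instead of LoRACompatibleLinear; LoRA scaling is handled by the PEFT backend
        self.proj = nn.Linear(dim_in, dim_out)

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # a leftover `scale` argument is only detected and warned about, never used
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = (
                "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it "
                "will raise an error in the future. `scale` should directly be passed while calling the "
                "underlying pipeline component i.e., via `cross_attention_kwargs`."
            )
            deprecate("scale", "1.0.0", deprecation_message)
        return self.proj(hidden_states)

At the pipeline level the LoRA scale is still applied the documented way, by passing it through the pipeline call rather than to individual layers, e.g. `pipe(prompt, cross_attention_kwargs={"scale": 0.7})`, which is what the deprecation message points to.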
@@ -17,8 +17,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
-from ..utils import USE_PEFT_BACKEND
-from .lora import LoRACompatibleLinear
+from ..utils import deprecate
 ACTIVATION_FUNCTIONS = {
@@ -87,9 +86,7 @@ class GEGLU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         super().__init__()
-        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
-        self.proj = linear_cls(dim_in, dim_out * 2, bias=bias)
+        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
@@ -97,9 +94,12 @@ class GEGLU(nn.Module):
         # mps: gelu is not implemented for float16
         return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
-    def forward(self, hidden_states, scale: float = 1.0):
-        args = () if USE_PEFT_BACKEND else (scale,)
-        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
+    def forward(self, hidden_states, *args, **kwargs):
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
         return hidden_states * self.gelu(gate)
...
@@ -17,18 +17,18 @@ import torch
 import torch.nn.functional as F
 from torch import nn
-from ..utils import USE_PEFT_BACKEND
+from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU
 from .attention_processor import Attention
 from .embeddings import SinusoidalPositionalEmbedding
-from .lora import LoRACompatibleLinear
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm
-def _chunked_feed_forward(
-    ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
-):
+logger = logging.get_logger(__name__)
+def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int):
     # "feed_forward_chunk_size" can be used to save memory
     if hidden_states.shape[chunk_dim] % chunk_size != 0:
         raise ValueError(
@@ -36,18 +36,10 @@ def _chunked_feed_forward(
         )
     num_chunks = hidden_states.shape[chunk_dim] // chunk_size
-    if lora_scale is None:
-        ff_output = torch.cat(
-            [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-            dim=chunk_dim,
-        )
-    else:
-        # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
-        ff_output = torch.cat(
-            [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-            dim=chunk_dim,
-        )
+    ff_output = torch.cat(
+        [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+        dim=chunk_dim,
+    )
     return ff_output
@@ -299,6 +291,10 @@ class BasicTransformerBlock(nn.Module):
         class_labels: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.FloatTensor:
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Self-Attention
         batch_size = hidden_states.shape[0]
@@ -326,10 +322,7 @@ class BasicTransformerBlock(nn.Module):
         if self.pos_embed is not None:
             norm_hidden_states = self.pos_embed(norm_hidden_states)
-        # 1. Retrieve lora scale.
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
-        # 2. Prepare GLIGEN inputs
+        # 1. Prepare GLIGEN inputs
         cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
@@ -348,7 +341,7 @@ class BasicTransformerBlock(nn.Module):
         if hidden_states.ndim == 4:
             hidden_states = hidden_states.squeeze(1)
-        # 2.5 GLIGEN Control
+        # 1.2 GLIGEN Control
         if gligen_kwargs is not None:
             hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
@@ -394,11 +387,9 @@ class BasicTransformerBlock(nn.Module):
         if self._chunk_size is not None:
             # "feed_forward_chunk_size" can be used to save memory
-            ff_output = _chunked_feed_forward(
-                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
-            )
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
         else:
-            ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+            ff_output = self.ff(norm_hidden_states)
         if self.norm_type == "ada_norm_zero":
             ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -643,7 +634,7 @@ class FeedForward(nn.Module):
         if inner_dim is None:
             inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
-        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+        linear_cls = nn.Linear
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim, bias=bias)
@@ -665,11 +656,10 @@ class FeedForward(nn.Module):
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
-    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
-        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         for module in self.net:
-            if isinstance(module, compatible_cls):
-                hidden_states = module(hidden_states, scale)
-            else:
-                hidden_states = module(hidden_states)
+            hidden_states = module(hidden_states)
         return hidden_states
@@ -20,10 +20,10 @@ import torch.nn.functional as F
 from torch import nn
 from ..image_processor import IPAdapterMaskProcessor
-from ..utils import USE_PEFT_BACKEND, deprecate, logging
+from ..utils import deprecate, logging
 from ..utils.import_utils import is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
-from .lora import LoRACompatibleLinear, LoRALinearLayer
+from .lora import LoRALinearLayer
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -181,10 +181,7 @@ class Attention(nn.Module):
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )
-        if USE_PEFT_BACKEND:
-            linear_cls = nn.Linear
-        else:
-            linear_cls = LoRACompatibleLinear
+        linear_cls = nn.Linear
         self.linear_cls = linear_cls
         self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
@@ -741,11 +738,14 @@ class AttnProcessor:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
-        residual = hidden_states
-        args = () if USE_PEFT_BACKEND else (scale,)
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+        residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -764,15 +764,15 @@ class AttnProcessor:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-        key = attn.to_k(encoder_hidden_states, *args)
-        value = attn.to_v(encoder_hidden_states, *args)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
         query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
@@ -783,7 +783,7 @@ class AttnProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -914,11 +914,14 @@ class AttnAddedKVProcessor:
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.Tensor:
-        residual = hidden_states
-        args = () if USE_PEFT_BACKEND else (scale,)
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+        residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
@@ -932,17 +935,17 @@ class AttnAddedKVProcessor:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)
         query = attn.head_to_batch_dim(query)
-        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, *args)
-        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, *args)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
         encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, *args)
-            value = attn.to_v(hidden_states, *args)
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
             key = attn.head_to_batch_dim(key)
             value = attn.head_to_batch_dim(value)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
@@ -956,7 +959,7 @@ class AttnAddedKVProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -984,11 +987,14 @@ class AttnAddedKVProcessor2_0:
         hidden_states: torch.FloatTensor,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
    ) -> torch.Tensor:
-        residual = hidden_states
-        args = () if USE_PEFT_BACKEND else (scale,)
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+        residual = hidden_states
         hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
         batch_size, sequence_length, _ = hidden_states.shape
@@ -1002,7 +1008,7 @@ class AttnAddedKVProcessor2_0:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)
         query = attn.head_to_batch_dim(query, out_dim=4)
         encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -1011,8 +1017,8 @@ class AttnAddedKVProcessor2_0:
         encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
         if not attn.only_cross_attention:
-            key = attn.to_k(hidden_states, *args)
-            value = attn.to_v(hidden_states, *args)
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
             key = attn.head_to_batch_dim(key, out_dim=4)
             value = attn.head_to_batch_dim(value, out_dim=4)
             key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
@@ -1029,7 +1035,7 @@ class AttnAddedKVProcessor2_0:
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -1132,11 +1138,14 @@ class XFormersAttnProcessor:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
    ) -> torch.FloatTensor:
-        residual = hidden_states
-        args = () if USE_PEFT_BACKEND else (scale,)
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+        residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1165,15 +1174,15 @@ class XFormersAttnProcessor:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-        key = attn.to_k(encoder_hidden_states, *args)
-        value = attn.to_v(encoder_hidden_states, *args)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()
         key = attn.head_to_batch_dim(key).contiguous()
@@ -1186,7 +1195,7 @@ class XFormersAttnProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -1217,8 +1226,13 @@ class AttnProcessor2_0:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
    ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1242,16 +1256,15 @@ class AttnProcessor2_0:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-        args = () if USE_PEFT_BACKEND else (scale,)
-        query = attn.to_q(hidden_states, *args)
+        query = attn.to_q(hidden_states)
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-        key = attn.to_k(encoder_hidden_states, *args)
-        value = attn.to_v(encoder_hidden_states, *args)
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1271,7 +1284,7 @@ class AttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -1312,8 +1325,13 @@ class FusedAttnProcessor2_0:
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         attention_mask: Optional[torch.FloatTensor] = None,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
    ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         residual = hidden_states
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -1337,17 +1355,16 @@ class FusedAttnProcessor2_0:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-        args = () if USE_PEFT_BACKEND else (scale,)
         if encoder_hidden_states is None:
-            qkv = attn.to_qkv(hidden_states, *args)
+            qkv = attn.to_qkv(hidden_states)
             split_size = qkv.shape[-1] // 3
             query, key, value = torch.split(qkv, split_size, dim=-1)
         else:
             if attn.norm_cross:
                 encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-            query = attn.to_q(hidden_states, *args)
-            kv = attn.to_kv(encoder_hidden_states, *args)
+            query = attn.to_q(hidden_states)
+            kv = attn.to_kv(encoder_hidden_states)
             split_size = kv.shape[-1] // 2
             key, value = torch.split(kv, split_size, dim=-1)
@@ -1368,7 +1385,7 @@ class FusedAttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, *args)
+        hidden_states = attn.to_out[0](hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -1859,7 +1876,7 @@ class LoRAAttnProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -1877,7 +1894,7 @@ class LoRAAttnProcessor(nn.Module):
         attn._modules.pop("processor")
         attn.processor = AttnProcessor()
-        return attn.processor(attn, hidden_states, *args, **kwargs)
+        return attn.processor(attn, hidden_states, **kwargs)
 class LoRAAttnProcessor2_0(nn.Module):
@@ -1920,7 +1937,7 @@ class LoRAAttnProcessor2_0(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -1938,7 +1955,7 @@ class LoRAAttnProcessor2_0(nn.Module):
         attn._modules.pop("processor")
         attn.processor = AttnProcessor2_0()
-        return attn.processor(attn, hidden_states, *args, **kwargs)
+        return attn.processor(attn, hidden_states, **kwargs)
 class LoRAXFormersAttnProcessor(nn.Module):
@@ -1999,7 +2016,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -2017,7 +2034,7 @@ class LoRAXFormersAttnProcessor(nn.Module):
         attn._modules.pop("processor")
         attn.processor = XFormersAttnProcessor()
-        return attn.processor(attn, hidden_states, *args, **kwargs)
+        return attn.processor(attn, hidden_states, **kwargs)
 class LoRAAttnAddedKVProcessor(nn.Module):
@@ -2058,7 +2075,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
         self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
         self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
-    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def __call__(self, attn: Attention, hidden_states: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         self_cls_name = self.__class__.__name__
         deprecate(
             self_cls_name,
@@ -2076,7 +2093,7 @@ class LoRAAttnAddedKVProcessor(nn.Module):
         attn._modules.pop("processor")
         attn.processor = AttnAddedKVProcessor()
-        return attn.processor(attn, hidden_states, *args, **kwargs)
+        return attn.processor(attn, hidden_states, **kwargs)
 class IPAdapterAttnProcessor(nn.Module):
...
@@ -18,8 +18,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from ..utils import USE_PEFT_BACKEND
-from .lora import LoRACompatibleConv
+from ..utils import deprecate
 from .normalization import RMSNorm
 from .upsampling import upfirdn2d_native
@@ -103,7 +102,7 @@ class Downsample2D(nn.Module):
         self.padding = padding
         stride = 2
         self.name = name
-        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        conv_cls = nn.Conv2d
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -131,7 +130,10 @@ class Downsample2D(nn.Module):
         else:
             self.conv = conv
-    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         assert hidden_states.shape[1] == self.channels
         if self.norm is not None:
@@ -143,13 +145,7 @@ class Downsample2D(nn.Module):
         assert hidden_states.shape[1] == self.channels
-        if not USE_PEFT_BACKEND:
-            if isinstance(self.conv, LoRACompatibleConv):
-                hidden_states = self.conv(hidden_states, scale)
-            else:
-                hidden_states = self.conv(hidden_states)
-        else:
-            hidden_states = self.conv(hidden_states)
+        hidden_states = self.conv(hidden_states)
         return hidden_states
...
@@ -18,10 +18,9 @@ import numpy as np
 import torch
 from torch import nn
-from ..utils import USE_PEFT_BACKEND, deprecate
+from ..utils import deprecate
 from .activations import get_activation
 from .attention_processor import Attention
-from .lora import LoRACompatibleLinear
 def get_timestep_embedding(
@@ -200,7 +199,7 @@ class TimestepEmbedding(nn.Module):
         sample_proj_bias=True,
     ):
         super().__init__()
-        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        linear_cls = nn.Linear
         self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
...
@@ -204,6 +204,9 @@ class LoRALinearLayer(nn.Module):
     ):
         super().__init__()
+        deprecation_message = "Use of `LoRALinearLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
+        deprecate("LoRALinearLayer", "1.0.0", deprecation_message)
         self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
         self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
@@ -264,6 +267,9 @@ class LoRAConv2dLayer(nn.Module):
     ):
         super().__init__()
+        deprecation_message = "Use of `LoRAConv2dLayer` is deprecated. Please switch to PEFT backend by installing PEFT: `pip install peft`."
+        deprecate("LoRAConv2dLayer", "1.0.0", deprecation_message)
         self.down = nn.Conv2d(in_features, rank, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
         # according to the official kohya_ss trainer kernel_size are always fixed for the up layer
         # # see: https://github.com/bmaltais/kohya_ss/blob/2accb1305979ba62f5077a23aabac23b4c37e935/networks/lora_diffusers.py#L129
...
@@ -20,7 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from ..utils import USE_PEFT_BACKEND
+from ..utils import deprecate
 from .activations import get_activation
 from .attention_processor import SpatialNorm
 from .downsampling import (  # noqa
@@ -30,7 +30,6 @@ from .downsampling import (  # noqa
     KDownsample2D,
     downsample_2d,
 )
-from .lora import LoRACompatibleConv, LoRACompatibleLinear
 from .normalization import AdaGroupNorm
 from .upsampling import (  # noqa
     FirUpsample2D,
@@ -102,7 +101,7 @@ class ResnetBlockCondNorm2D(nn.Module):
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm
-        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        conv_cls = nn.Conv2d
         if groups_out is None:
             groups_out = groups
@@ -149,12 +148,11 @@ class ResnetBlockCondNorm2D(nn.Module):
                 bias=conv_shortcut_bias,
             )
-    def forward(
-        self,
-        input_tensor: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        scale: float = 1.0,
-    ) -> torch.FloatTensor:
+    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         hidden_states = input_tensor
         hidden_states = self.norm1(hidden_states, temb)
@@ -166,26 +164,24 @@ class ResnetBlockCondNorm2D(nn.Module):
             if hidden_states.shape[0] >= 64:
                 input_tensor = input_tensor.contiguous()
                 hidden_states = hidden_states.contiguous()
-            input_tensor = self.upsample(input_tensor, scale=scale)
-            hidden_states = self.upsample(hidden_states, scale=scale)
+            input_tensor = self.upsample(input_tensor)
+            hidden_states = self.upsample(hidden_states)
         elif self.downsample is not None:
-            input_tensor = self.downsample(input_tensor, scale=scale)
-            hidden_states = self.downsample(hidden_states, scale=scale)
+            input_tensor = self.downsample(input_tensor)
+            hidden_states = self.downsample(hidden_states)
-        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
+        hidden_states = self.conv1(hidden_states)
         hidden_states = self.norm2(hidden_states, temb)
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
+        hidden_states = self.conv2(hidden_states)
         if self.conv_shortcut is not None:
-            input_tensor = (
-                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
-            )
+            input_tensor = self.conv_shortcut(input_tensor)
         output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
@@ -267,8 +263,8 @@ class ResnetBlock2D(nn.Module):
         self.time_embedding_norm = time_embedding_norm
         self.skip_time_act = skip_time_act
-        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
-        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        linear_cls = nn.Linear
+        conv_cls = nn.Conv2d
         if groups_out is None:
             groups_out = groups
@@ -326,12 +322,11 @@ class ResnetBlock2D(nn.Module):
                 bias=conv_shortcut_bias,
             )
-    def forward(
-        self,
-        input_tensor: torch.FloatTensor,
-        temb: torch.FloatTensor,
-        scale: float = 1.0,
-    ) -> torch.FloatTensor:
+    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         hidden_states = input_tensor
         hidden_states = self.norm1(hidden_states)
@@ -342,38 +337,18 @@ class ResnetBlock2D(nn.Module):
             if hidden_states.shape[0] >= 64:
                 input_tensor = input_tensor.contiguous()
                 hidden_states = hidden_states.contiguous()
-            input_tensor = (
-                self.upsample(input_tensor, scale=scale)
-                if isinstance(self.upsample, Upsample2D)
-                else self.upsample(input_tensor)
-            )
-            hidden_states = (
-                self.upsample(hidden_states, scale=scale)
-                if isinstance(self.upsample, Upsample2D)
-                else self.upsample(hidden_states)
-            )
+            input_tensor = self.upsample(input_tensor)
+            hidden_states = self.upsample(hidden_states)
         elif self.downsample is not None:
-            input_tensor = (
-                self.downsample(input_tensor, scale=scale)
-                if isinstance(self.downsample, Downsample2D)
-                else self.downsample(input_tensor)
-            )
-            hidden_states = (
-                self.downsample(hidden_states, scale=scale)
-                if isinstance(self.downsample, Downsample2D)
-                else self.downsample(hidden_states)
-            )
+            input_tensor = self.downsample(input_tensor)
+            hidden_states = self.downsample(hidden_states)
-        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
+        hidden_states = self.conv1(hidden_states)
         if self.time_emb_proj is not None:
             if not self.skip_time_act:
                 temb = self.nonlinearity(temb)
-            temb = (
-                self.time_emb_proj(temb, scale)[:, :, None, None]
-                if not USE_PEFT_BACKEND
-                else self.time_emb_proj(temb)[:, :, None, None]
-            )
+            temb = self.time_emb_proj(temb)[:, :, None, None]
         if self.time_embedding_norm == "default":
             if temb is not None:
@@ -393,12 +368,10 @@ class ResnetBlock2D(nn.Module):
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
+        hidden_states = self.conv2(hidden_states)
         if self.conv_shortcut is not None:
-            input_tensor = (
-                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
-            )
+            input_tensor = self.conv_shortcut(input_tensor)
         output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
...
@@ -19,14 +19,16 @@ import torch.nn.functional as F
 from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
+from ...utils import BaseOutput, deprecate, is_torch_version, logging
 from ..attention import BasicTransformerBlock
 from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
-from ..lora import LoRACompatibleConv, LoRACompatibleLinear
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 @dataclass
 class Transformer2DModelOutput(BaseOutput):
     """
@@ -115,8 +117,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         self.attention_head_dim = attention_head_dim
         inner_dim = num_attention_heads * attention_head_dim
-        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
-        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        conv_cls = nn.Conv2d
+        linear_cls = nn.Linear
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
@@ -304,6 +306,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
         # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
         # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
         # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
@@ -327,9 +332,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
             encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
-        # Retrieve lora scale.
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
         # 1. Input
         if self.is_input_continuous:
             batch, _, height, width = hidden_states.shape
@@ -337,21 +339,13 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             hidden_states = self.norm(hidden_states)
             if not self.use_linear_projection:
-                hidden_states = (
-                    self.proj_in(hidden_states, scale=lora_scale)
-                    if not USE_PEFT_BACKEND
-                    else self.proj_in(hidden_states)
-                )
+                hidden_states = self.proj_in(hidden_states)
                 inner_dim = hidden_states.shape[1]
                 hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
             else:
                 inner_dim = hidden_states.shape[1]
                 hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = (
-                    self.proj_in(hidden_states, scale=lora_scale)
-                    if not USE_PEFT_BACKEND
-                    else self.proj_in(hidden_states)
-                )
+                hidden_states = self.proj_in(hidden_states)
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
@@ -414,17 +408,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         if self.is_input_continuous:
             if not self.use_linear_projection:
                 hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-                hidden_states = (
-                    self.proj_out(hidden_states, scale=lora_scale)
-                    if not USE_PEFT_BACKEND
-                    else self.proj_out(hidden_states)
-                )
+                hidden_states = self.proj_out(hidden_states)
             else:
-                hidden_states = (
-                    self.proj_out(hidden_states, scale=lora_scale)
-                    if not USE_PEFT_BACKEND
-                    else self.proj_out(hidden_states)
-                )
+                hidden_states = self.proj_out(hidden_states)
                 hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
             output = hidden_states + residual
...
@@ -1226,7 +1226,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
                     **additional_residuals,
                 )
             else:
-                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
                 if is_adapter and len(down_intrablock_additional_residuals) > 0:
                     sample += down_intrablock_additional_residuals.pop(0)
@@ -1297,7 +1297,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
                     upsample_size=upsample_size,
-                    scale=lora_scale,
                 )
         # 6. post-process
...
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
 import torch
 from torch import nn
-from ...utils import is_torch_version
+from ...utils import deprecate, is_torch_version, logging
 from ...utils.torch_utils import apply_freeu
 from ..attention import Attention
 from ..resnet import (
@@ -35,6 +35,9 @@ from ..transformers.transformer_temporal import (
 )
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 def get_down_block(
     down_block_type: str,
     num_layers: int,
@@ -1005,9 +1008,14 @@ class DownBlockMotion(nn.Module):
         self,
         hidden_states: torch.FloatTensor,
         temb: Optional[torch.FloatTensor] = None,
-        scale: float = 1.0,
         num_frames: int = 1,
+        *args,
+        **kwargs,
     ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         output_states = ()
         blocks = zip(self.resnets, self.motion_modules)
@@ -1029,18 +1037,18 @@ class DownBlockMotion(nn.Module):
                    )
                 else:
                     hidden_states = torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(resnet), hidden_states, temb, scale
+                        create_custom_forward(resnet), hidden_states, temb
                     )
             else:
-                hidden_states = resnet(hidden_states, temb, scale=scale)
+                hidden_states = resnet(hidden_states, temb)
             hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
             output_states = output_states + (hidden_states,)
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states, scale=scale)
+                hidden_states = downsampler(hidden_states)
             output_states = output_states + (hidden_states,)
@@ -1173,9 +1181,11 @@ class CrossAttnDownBlockMotion(nn.Module):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         additional_residuals: Optional[torch.FloatTensor] = None,
     ):
-        output_states = ()
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        output_states = ()
         blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
         for i, (resnet, attn, motion_module) in enumerate(blocks):
@@ -1206,7 +1216,7 @@ class CrossAttnDownBlockMotion(nn.Module):
                         return_dict=False,
                     )[0]
             else:
-                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+                hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1228,7 +1238,7 @@ class CrossAttnDownBlockMotion(nn.Module):
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states, scale=lora_scale)
+                hidden_states = downsampler(hidden_states)
                 output_states = output_states + (hidden_states,)
@@ -1355,7 +1365,10 @@ class CrossAttnUpBlockMotion(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         num_frames: int = 1,
     ) -> torch.FloatTensor:
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -1410,7 +1423,7 @@ class CrossAttnUpBlockMotion(nn.Module):
                         return_dict=False,
                     )[0]
             else:
-                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+                hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1426,7 +1439,7 @@ class CrossAttnUpBlockMotion(nn.Module):
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
+                hidden_states = upsampler(hidden_states, upsample_size)
         return hidden_states
@@ -1507,9 +1520,14 @@ class UpBlockMotion(nn.Module):
         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
         temb: Optional[torch.FloatTensor] = None,
         upsample_size=None,
-        scale: float = 1.0,
         num_frames: int = 1,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -1559,12 +1577,12 @@ class UpBlockMotion(nn.Module):
                     )
             else:
-                hidden_states = resnet(hidden_states, temb, scale=scale)
+                hidden_states = resnet(hidden_states, temb)
             hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
+                hidden_states = upsampler(hidden_states, upsample_size)
         return hidden_states
@@ -1687,8 +1705,11 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
         num_frames: int = 1,
     ) -> torch.FloatTensor:
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
-        hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+        hidden_states = self.resnets[0](hidden_states, temb)
         blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
         for attn, resnet, motion_module in blocks:
@@ -1737,7 +1758,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
                 hidden_states,
                 num_frames=num_frames,
             )[0]
-            hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+            hidden_states = resnet(hidden_states, temb)
         return hidden_states
...
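The same `*args, **kwargs` shim recurs in every block above: the old `scale` keyword is still accepted, a deprecation warning is raised, and the value is otherwise ignored. A standalone sketch of the pattern, using the standard library in place of `diffusers.utils.deprecate` (the function name and arguments here are illustrative, not taken from this diff):

import warnings

def forward_like(hidden_states, temb=None, *args, **kwargs):
    # Accept the legacy `scale` keyword so old callers keep working, but warn
    # that it no longer has any effect and will eventually become an error.
    if len(args) > 0 or kwargs.get("scale", None) is not None:
        warnings.warn(
            "The `scale` argument is deprecated and will be ignored. "
            "Pass it via `cross_attention_kwargs` on the pipeline instead.",
            FutureWarning,
        )
    # ... the actual forward computation continues here without `scale` ...
    return hidden_states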
@@ -89,7 +89,7 @@ class I2VGenXLTransformerTemporalEncoder(nn.Module):
         if hidden_states.ndim == 4:
             hidden_states = hidden_states.squeeze(1)
-        ff_output = self.ff(hidden_states, scale=1.0)
+        ff_output = self.ff(hidden_states)
         hidden_states = ff_output + hidden_states
         if hidden_states.ndim == 4:
             hidden_states = hidden_states.squeeze(1)
...
@@ -18,8 +18,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from ..utils import USE_PEFT_BACKEND
-from .lora import LoRACompatibleConv
+from ..utils import deprecate
 from .normalization import RMSNorm
@@ -111,7 +110,7 @@ class Upsample2D(nn.Module):
         self.use_conv_transpose = use_conv_transpose
         self.name = name
         self.interpolate = interpolate
-        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        conv_cls = nn.Conv2d
         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -141,11 +140,12 @@ class Upsample2D(nn.Module):
             self.Conv2d_0 = conv
     def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        output_size: Optional[int] = None,
-        scale: float = 1.0,
+        self, hidden_states: torch.FloatTensor, output_size: Optional[int] = None, *args, **kwargs
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         assert hidden_states.shape[1] == self.channels
         if self.norm is not None:
@@ -180,15 +180,9 @@ class Upsample2D(nn.Module):
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
             if self.name == "conv":
-                if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
-                    hidden_states = self.conv(hidden_states, scale)
-                else:
-                    hidden_states = self.conv(hidden_states)
+                hidden_states = self.conv(hidden_states)
             else:
-                if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
-                    hidden_states = self.Conv2d_0(hidden_states, scale)
-                else:
-                    hidden_states = self.Conv2d_0(hidden_states)
+                hidden_states = self.Conv2d_0(hidden_states)
         return hidden_states
...
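A usage sketch for the updated `Upsample2D.forward`, assuming a diffusers build that includes this change: `scale` is gone from the signature, and passing it now only triggers the deprecation path shown above (channel counts and tensor shape below are arbitrary example values):

import torch
from diffusers.models.resnet import Upsample2D

upsample = Upsample2D(channels=32, use_conv=True)
x = torch.randn(1, 32, 8, 8)

out = upsample(x)             # new calling convention, no `scale`
print(out.shape)              # torch.Size([1, 32, 16, 16])

out = upsample(x, scale=1.0)  # still accepted, but emits the deprecation warning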
@@ -1333,7 +1333,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                     **additional_residuals,
                 )
             else:
-                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
                 if is_adapter and len(down_intrablock_additional_residuals) > 0:
                     sample += down_intrablock_additional_residuals.pop(0)
@@ -1589,7 +1589,7 @@ class DownBlockFlat(nn.Module):
         self.gradient_checkpointing = False
     def forward(
-        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+        self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
         output_states = ()
@@ -1611,13 +1611,13 @@ class DownBlockFlat(nn.Module):
                         create_custom_forward(resnet), hidden_states, temb
                     )
             else:
-                hidden_states = resnet(hidden_states, temb, scale=scale)
+                hidden_states = resnet(hidden_states, temb)
             output_states = output_states + (hidden_states,)
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states, scale=scale)
+                hidden_states = downsampler(hidden_states)
                 output_states = output_states + (hidden_states,)
@@ -1728,8 +1728,6 @@ class CrossAttnDownBlockFlat(nn.Module):
     ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
         output_states = ()
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
         blocks = list(zip(self.resnets, self.attentions))
         for i, (resnet, attn) in enumerate(blocks):
@@ -1760,7 +1758,7 @@ class CrossAttnDownBlockFlat(nn.Module):
                         return_dict=False,
                     )[0]
             else:
-                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+                hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -1778,7 +1776,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states, scale=lora_scale)
+                hidden_states = downsampler(hidden_states)
                 output_states = output_states + (hidden_states,)
@@ -1842,8 +1840,13 @@ class UpBlockFlat(nn.Module):
         res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
         temb: Optional[torch.FloatTensor] = None,
         upsample_size: Optional[int] = None,
-        scale: float = 1.0,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -1887,11 +1890,11 @@ class UpBlockFlat(nn.Module):
                         create_custom_forward(resnet), hidden_states, temb
                     )
             else:
-                hidden_states = resnet(hidden_states, temb, scale=scale)
+                hidden_states = resnet(hidden_states, temb)
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
+                hidden_states = upsampler(hidden_states, upsample_size)
         return hidden_states
@@ -1999,7 +2002,10 @@ class CrossAttnUpBlockFlat(nn.Module):
         attention_mask: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -2053,7 +2059,7 @@ class CrossAttnUpBlockFlat(nn.Module):
                         return_dict=False,
                     )[0]
             else:
-                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+                hidden_states = resnet(hidden_states, temb)
                 hidden_states = attn(
                     hidden_states,
                     encoder_hidden_states=encoder_hidden_states,
@@ -2065,7 +2071,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
+                hidden_states = upsampler(hidden_states, upsample_size)
         return hidden_states
@@ -2330,8 +2336,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
-        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
-        hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+        hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             if self.training and self.gradient_checkpointing:
@@ -2368,7 +2377,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                     encoder_attention_mask=encoder_attention_mask,
                     return_dict=False,
                 )[0]
-                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+                hidden_states = resnet(hidden_states, temb)
         return hidden_states
@@ -2469,7 +2478,8 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
         encoder_attention_mask: Optional[torch.FloatTensor] = None,
     ) -> torch.FloatTensor:
         cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
-        lora_scale = cross_attention_kwargs.get("scale", 1.0)
+        if cross_attention_kwargs.get("scale", None) is not None:
+            logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
         if attention_mask is None:
             # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -2482,7 +2492,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
             # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
             mask = attention_mask
-        hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+        hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
             # attn
             hidden_states = attn(
@@ -2493,6 +2503,6 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
             )
             # resnet
-            hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+            hidden_states = resnet(hidden_states, temb)
         return hidden_states
@@ -2,8 +2,6 @@ import torch
 import torch.nn as nn
 from ...models.attention_processor import Attention
-from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
-from ...utils import USE_PEFT_BACKEND
 class WuerstchenLayerNorm(nn.LayerNorm):
@@ -19,7 +17,7 @@ class WuerstchenLayerNorm(nn.LayerNorm):
 class TimestepBlock(nn.Module):
     def __init__(self, c, c_timestep):
         super().__init__()
-        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        linear_cls = nn.Linear
         self.mapper = linear_cls(c_timestep, c * 2)
     def forward(self, x, t):
@@ -31,8 +29,8 @@ class ResBlock(nn.Module):
     def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
         super().__init__()
-        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
-        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        conv_cls = nn.Conv2d
+        linear_cls = nn.Linear
         self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
@@ -66,7 +64,7 @@ class AttnBlock(nn.Module):
     def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
         super().__init__()
-        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        linear_cls = nn.Linear
         self.self_attn = self_attn
         self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
...
@@ -28,9 +28,8 @@ from ...models.attention_processor import (
     AttnAddedKVProcessor,
     AttnProcessor,
 )
-from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version
+from ...utils import is_torch_version
 from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
@@ -41,8 +40,8 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
     @register_to_config
     def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
         super().__init__()
-        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
-        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        conv_cls = nn.Conv2d
+        linear_cls = nn.Linear
         self.c_r = c_r
         self.projection = conv_cls(c_in, c, kernel_size=1)
...
@@ -22,7 +22,6 @@ from torch import nn
 from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
 from diffusers.models.embeddings import get_timestep_embedding
-from diffusers.models.lora import LoRACompatibleLinear
 from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
 from diffusers.models.transformers.transformer_2d import Transformer2DModel
 from diffusers.utils.testing_utils import (
@@ -482,7 +481,7 @@ class Transformer2DModelTests(unittest.TestCase):
         assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
         assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
-        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
         dim = 32
         inner_dim = 128
@@ -506,7 +505,7 @@ class Transformer2DModelTests(unittest.TestCase):
         assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
         assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
-        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == LoRACompatibleLinear
+        assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
         dim = 32
         inner_dim = 128
...