Unverified Commit c81a88b2 authored by Sayak Paul, committed by GitHub

[Core] LoRA improvements pt. 3 (#4842)



* throw warning when more than one lora is attempted to be fused.

* introduce support of lora scale during fusion (see the usage sketch below).

* change test name

* changes

* change to _lora_scale

* lora_scale to call whenever applicable.

* debugging

* lora_scale additional.

* cross_attention_kwargs

* lora_scale -> scale.

* lora_scale fix

* lora_scale in patched projection.

* debugging

* styling.

* debugging

* remove unneeded prints.

* assign cross_attention_kwargs.

* debugging

* clean up.

* refactor scale retrieval logic a bit.

* fix NoneType

* fix: tests

* add more tests

* more fixes.

* figure out a way to pass lora_scale.

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* unify the retrieval logic of lora_scale.

* move adjust_lora_scale_text_encoder to lora.py.

* introduce dynamic LoRA scale adjustment support to SD

* fix up copies

* Empty-Commit

* add: test to check fusion equivalence on different scales.

* handle lora fusion warning.

* make lora smaller

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2c1677ee
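
A minimal usage sketch of what this PR enables at the pipeline level. The model id, LoRA path, and prompt below are placeholders; only `load_lora_weights`, `fuse_lora(lora_scale=...)`, `unfuse_lora()`, and the `cross_attention_kwargs={"scale": ...}` plumbing are the APIs touched by this diff.

```python
from diffusers import StableDiffusionPipeline

# Placeholder checkpoint and LoRA file; any SD pipeline with LoRA weights works the same way.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.load_lora_weights("path/to/lora")

# Runtime scaling: the UNet blocks read the scale from cross_attention_kwargs["scale"],
# and the pipelines now forward it to the text encoder via adjust_lora_scale_text_encoder.
image = pipe("a prompt", cross_attention_kwargs={"scale": 0.5}).images[0]

# Alternatively, bake the LoRA into the base weights with a fixed scale.
pipe.fuse_lora(lora_scale=0.5)
image = pipe("a prompt").images[0]

# Fusing a second LoRA would now log the new "single LoRA" warning;
# unfuse_lora() subtracts the stored scaled update and restores the original weights.
pipe.unfuse_lora()
```
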
......@@ -95,7 +95,7 @@ class PatchedLoraProjection(nn.Module):
return super().state_dict(*args, destination=destination, prefix=prefix, keep_vars=keep_vars)
def _fuse_lora(self):
def _fuse_lora(self, lora_scale=1.0):
if self.lora_linear_layer is None:
return
......@@ -108,7 +108,7 @@ class PatchedLoraProjection(nn.Module):
if self.lora_linear_layer.network_alpha is not None:
w_up = w_up * self.lora_linear_layer.network_alpha / self.lora_linear_layer.rank
fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
self.regular_linear_layer.weight.data = fused_weight.to(device=device, dtype=dtype)
# we can drop the lora layer now
......@@ -117,6 +117,7 @@ class PatchedLoraProjection(nn.Module):
# offload the up and down matrices to CPU to not blow the memory
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
self.lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
......@@ -128,16 +129,19 @@ class PatchedLoraProjection(nn.Module):
w_up = self.w_up.to(device=device).float()
w_down = self.w_down.to(device).float()
unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
unfused_weight = fused_weight.float() - (self.lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
self.regular_linear_layer.weight.data = unfused_weight.to(device=device, dtype=dtype)
self.w_up = None
self.w_down = None
def forward(self, input):
# print(f"{self.__class__.__name__} has a lora_scale of {self.lora_scale}")
if self.lora_scale is None:
self.lora_scale = 1.0
if self.lora_linear_layer is None:
return self.regular_linear_layer(input)
return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
return self.regular_linear_layer(input) + (self.lora_scale * self.lora_linear_layer(input))
def text_encoder_attn_modules(text_encoder):
......@@ -576,12 +580,13 @@ class UNet2DConditionLoadersMixin:
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
def fuse_lora(self):
def fuse_lora(self, lora_scale=1.0):
self.lora_scale = lora_scale
self.apply(self._fuse_lora_apply)
def _fuse_lora_apply(self, module):
if hasattr(module, "_fuse_lora"):
module._fuse_lora()
module._fuse_lora(self.lora_scale)
def unfuse_lora(self):
self.apply(self._unfuse_lora_apply)
......@@ -924,6 +929,7 @@ class LoraLoaderMixin:
"""
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
num_fused_loras = 0
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
"""
......@@ -1807,7 +1813,7 @@ class LoraLoaderMixin:
# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()
def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True):
def fuse_lora(self, fuse_unet: bool = True, fuse_text_encoder: bool = True, lora_scale: float = 1.0):
r"""
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
......@@ -1822,22 +1828,31 @@ class LoraLoaderMixin:
fuse_text_encoder (`bool`, defaults to `True`):
Whether to fuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
lora_scale (`float`, defaults to 1.0):
Controls how much to influence the outputs with the LoRA parameters.
"""
if fuse_unet or fuse_text_encoder:
self.num_fused_loras += 1
if self.num_fused_loras > 1:
logger.warn(
"The current API is supported for operating with a single LoRA file. You are trying to load and fuse more than one LoRA which is not well-supported.",
)
if fuse_unet:
self.unet.fuse_lora()
self.unet.fuse_lora(lora_scale)
def fuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora()
attn_module.k_proj._fuse_lora()
attn_module.v_proj._fuse_lora()
attn_module.out_proj._fuse_lora()
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora()
mlp_module.fc2._fuse_lora()
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)
if fuse_text_encoder:
if hasattr(self, "text_encoder"):
......@@ -1884,6 +1899,8 @@ class LoraLoaderMixin:
if hasattr(self, "text_encoder_2"):
unfuse_text_encoder_lora(self.text_encoder_2)
self.num_fused_loras -= 1
class FromSingleFileMixin:
"""
......
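
The `_fuse_lora`/`_unfuse_lora` changes above reduce to adding and later subtracting the scaled low-rank update. A standalone sketch of that arithmetic with plain tensors (no diffusers classes), mirroring the `torch.bmm` expression in the diff:

```python
import torch

torch.manual_seed(0)
out_features, in_features, rank = 8, 8, 4
lora_scale = 0.5

w_orig = torch.randn(out_features, in_features)   # regular_linear_layer.weight
w_up = torch.randn(out_features, rank)            # lora_linear_layer.up.weight
w_down = torch.randn(rank, in_features)           # lora_linear_layer.down.weight

# _fuse_lora(lora_scale): add the scaled low-rank update to the base weight.
fused = w_orig + lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]

# _unfuse_lora(): subtract the same scaled update; the scale is kept on the
# module (self.lora_scale / self._lora_scale) precisely so this round-trips.
unfused = fused - lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0]

assert torch.allclose(unfused, w_orig, atol=1e-6)
```
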
......@@ -177,7 +177,7 @@ class BasicTransformerBlock(nn.Module):
class_labels: Optional[torch.LongTensor] = None,
):
# Notice that normalization is always applied before the real computation in the following blocks.
# 1. Self-Attention
# 0. Self-Attention
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
......@@ -187,7 +187,10 @@ class BasicTransformerBlock(nn.Module):
else:
norm_hidden_states = self.norm1(hidden_states)
# 0. Prepare GLIGEN inputs
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 2. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
......@@ -201,12 +204,12 @@ class BasicTransformerBlock(nn.Module):
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
# 1.5 GLIGEN Control
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 1.5 ends
# 2.5 ends
# 2. Cross-Attention
# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
......@@ -220,7 +223,7 @@ class BasicTransformerBlock(nn.Module):
)
hidden_states = attn_output + hidden_states
# 3. Feed-forward
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
......@@ -235,11 +238,14 @@ class BasicTransformerBlock(nn.Module):
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
[
self.ff(hid_slice, scale=lora_scale)
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states)
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
......@@ -295,9 +301,12 @@ class FeedForward(nn.Module):
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states):
def forward(self, hidden_states, scale: float = 1.0):
for module in self.net:
hidden_states = module(hidden_states)
if isinstance(module, (LoRACompatibleLinear, GEGLU)):
hidden_states = module(hidden_states, scale)
else:
hidden_states = module(hidden_states)
return hidden_states
......@@ -342,8 +351,8 @@ class GEGLU(nn.Module):
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
def forward(self, hidden_states, scale: float = 1.0):
hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
......
......@@ -570,15 +570,15 @@ class AttnProcessor:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, lora_scale=scale)
query = attn.to_q(hidden_states, scale=scale)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, lora_scale=scale)
value = attn.to_v(encoder_hidden_states, lora_scale=scale)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
......@@ -589,7 +589,7 @@ class AttnProcessor:
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)
......@@ -722,17 +722,17 @@ class AttnAddedKVProcessor:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, lora_scale=scale)
query = attn.to_q(hidden_states, scale=scale)
query = attn.head_to_batch_dim(query)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, lora_scale=scale)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, lora_scale=scale)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states, scale=scale)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states, scale=scale)
encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states, lora_scale=scale)
value = attn.to_v(hidden_states, lora_scale=scale)
key = attn.to_k(hidden_states, scale=scale)
value = attn.to_v(hidden_states, scale=scale)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
......@@ -746,7 +746,7 @@ class AttnAddedKVProcessor:
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)
......@@ -782,7 +782,7 @@ class AttnAddedKVProcessor2_0:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, lora_scale=scale)
query = attn.to_q(hidden_states, scale=scale)
query = attn.head_to_batch_dim(query, out_dim=4)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
......@@ -791,8 +791,8 @@ class AttnAddedKVProcessor2_0:
encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj, out_dim=4)
if not attn.only_cross_attention:
key = attn.to_k(hidden_states, lora_scale=scale)
value = attn.to_v(hidden_states, lora_scale=scale)
key = attn.to_k(hidden_states, scale=scale)
value = attn.to_v(hidden_states, scale=scale)
key = attn.head_to_batch_dim(key, out_dim=4)
value = attn.head_to_batch_dim(value, out_dim=4)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
......@@ -809,7 +809,7 @@ class AttnAddedKVProcessor2_0:
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, residual.shape[1])
# linear proj
hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)
......@@ -937,15 +937,15 @@ class XFormersAttnProcessor:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, lora_scale=scale)
query = attn.to_q(hidden_states, scale=scale)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, lora_scale=scale)
value = attn.to_v(encoder_hidden_states, lora_scale=scale)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
query = attn.head_to_batch_dim(query).contiguous()
key = attn.head_to_batch_dim(key).contiguous()
......@@ -958,7 +958,7 @@ class XFormersAttnProcessor:
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)
......@@ -1015,15 +1015,15 @@ class AttnProcessor2_0:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, lora_scale=scale)
query = attn.to_q(hidden_states, scale=scale)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, lora_scale=scale)
value = attn.to_v(encoder_hidden_states, lora_scale=scale)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
......@@ -1043,7 +1043,7 @@ class AttnProcessor2_0:
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, lora_scale=scale)
hidden_states = attn.to_out[0](hidden_states, scale=scale)
# dropout
hidden_states = attn.to_out[1](hidden_states)
......
......@@ -18,12 +18,27 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..loaders import PatchedLoraProjection, text_encoder_attn_modules, text_encoder_mlp_modules
from ..utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_scale = lora_scale
attn_module.k_proj.lora_scale = lora_scale
attn_module.v_proj.lora_scale = lora_scale
attn_module.out_proj.lora_scale = lora_scale
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1.lora_scale = lora_scale
mlp_module.fc2.lora_scale = lora_scale
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4, network_alpha=None, device=None, dtype=None):
super().__init__()
......@@ -97,12 +112,11 @@ class LoRACompatibleConv(nn.Conv2d):
def set_lora_layer(self, lora_layer: Optional[LoRAConv2dLayer]):
self.lora_layer = lora_layer
def _fuse_lora(self):
def _fuse_lora(self, lora_scale=1.0):
if self.lora_layer is None:
return
dtype, device = self.weight.data.dtype, self.weight.data.device
logger.info(f"Fusing LoRA weights for {self.__class__}")
w_orig = self.weight.data.float()
w_up = self.lora_layer.up.weight.data.float()
......@@ -113,7 +127,7 @@ class LoRACompatibleConv(nn.Conv2d):
fusion = torch.mm(w_up.flatten(start_dim=1), w_down.flatten(start_dim=1))
fusion = fusion.reshape((w_orig.shape))
fused_weight = w_orig + fusion
fused_weight = w_orig + (lora_scale * fusion)
self.weight.data = fused_weight.to(device=device, dtype=dtype)
# we can drop the lora layer now
......@@ -122,33 +136,35 @@ class LoRACompatibleConv(nn.Conv2d):
# offload the up and down matrices to CPU to not blow the memory
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
self._lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
return
logger.info(f"Unfusing LoRA weights for {self.__class__}")
fused_weight = self.weight.data
dtype, device = fused_weight.data.dtype, fused_weight.data.device
self.w_up = self.w_up.to(device=device, dtype=dtype)
self.w_down = self.w_down.to(device, dtype=dtype)
self.w_up = self.w_up.to(device=device).float()
self.w_down = self.w_down.to(device).float()
fusion = torch.mm(self.w_up.flatten(start_dim=1), self.w_down.flatten(start_dim=1))
fusion = fusion.reshape((fused_weight.shape))
unfused_weight = fused_weight - fusion
unfused_weight = fused_weight.float() - (self._lora_scale * fusion)
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
self.w_up = None
self.w_down = None
def forward(self, x):
def forward(self, hidden_states, scale: float = 1.0):
if self.lora_layer is None:
# make sure to the functional Conv2D function as otherwise torch.compile's graph will break
# see: https://github.com/huggingface/diffusers/pull/4315
return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
return F.conv2d(
hidden_states, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
)
else:
return super().forward(x) + self.lora_layer(x)
return super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
class LoRACompatibleLinear(nn.Linear):
......@@ -163,7 +179,7 @@ class LoRACompatibleLinear(nn.Linear):
def set_lora_layer(self, lora_layer: Optional[LoRALinearLayer]):
self.lora_layer = lora_layer
def _fuse_lora(self):
def _fuse_lora(self, lora_scale=1.0):
if self.lora_layer is None:
return
......@@ -176,7 +192,7 @@ class LoRACompatibleLinear(nn.Linear):
if self.lora_layer.network_alpha is not None:
w_up = w_up * self.lora_layer.network_alpha / self.lora_layer.rank
fused_weight = w_orig + torch.bmm(w_up[None, :], w_down[None, :])[0]
fused_weight = w_orig + (lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
self.weight.data = fused_weight.to(device=device, dtype=dtype)
# we can drop the lora layer now
......@@ -185,6 +201,7 @@ class LoRACompatibleLinear(nn.Linear):
# offload the up and down matrices to CPU to not blow the memory
self.w_up = w_up.cpu()
self.w_down = w_down.cpu()
self._lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
......@@ -196,14 +213,16 @@ class LoRACompatibleLinear(nn.Linear):
w_up = self.w_up.to(device=device).float()
w_down = self.w_down.to(device).float()
unfused_weight = fused_weight.float() - torch.bmm(w_up[None, :], w_down[None, :])[0]
unfused_weight = fused_weight.float() - (self._lora_scale * torch.bmm(w_up[None, :], w_down[None, :])[0])
self.weight.data = unfused_weight.to(device=device, dtype=dtype)
self.w_up = None
self.w_down = None
def forward(self, hidden_states, lora_scale: int = 1):
def forward(self, hidden_states, scale: float = 1.0):
if self.lora_layer is None:
return super().forward(hidden_states)
out = super().forward(hidden_states)
return out
else:
return super().forward(hidden_states) + lora_scale * self.lora_layer(hidden_states)
out = super().forward(hidden_states) + (scale * self.lora_layer(hidden_states))
return out
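
A module-level version of the same round trip using the classes changed in this file. The import path follows this diff (`diffusers.models.lora`); the LoRA up-projection is randomized only because its default initialization is zeros, and the assumption that `_fuse_lora` drops the attached `lora_layer` afterwards matches the "we can drop the lora layer now" comment above.

```python
import torch
from diffusers.models.lora import LoRACompatibleLinear, LoRALinearLayer

torch.manual_seed(0)
layer = LoRACompatibleLinear(16, 16)
lora = LoRALinearLayer(16, 16, rank=4)
torch.nn.init.normal_(lora.up.weight, std=0.1)  # default up-weight init is zeros
layer.set_lora_layer(lora)

x = torch.randn(2, 16)
scale = 0.5

with torch.no_grad():
    runtime_out = layer(x, scale=scale)    # un-fused path: base output + scale * lora(x)
    layer._fuse_lora(lora_scale=scale)     # bake the scaled update into layer.weight
    fused_out = layer(x)                   # lora_layer has been dropped after fusion
    layer._unfuse_lora()                   # subtract the stored scaled update again

print(torch.allclose(runtime_out, fused_out, atol=1e-5))  # expected: True
```
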
......@@ -135,7 +135,7 @@ class Upsample2D(nn.Module):
else:
self.Conv2d_0 = conv
def forward(self, hidden_states, output_size=None):
def forward(self, hidden_states, output_size=None, scale: float = 1.0):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
......@@ -166,9 +166,15 @@ class Upsample2D(nn.Module):
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
hidden_states = self.conv(hidden_states)
if isinstance(self.conv, LoRACompatibleConv):
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.Conv2d_0(hidden_states)
if isinstance(self.Conv2d_0, LoRACompatibleConv):
hidden_states = self.Conv2d_0(hidden_states, scale)
else:
hidden_states = self.Conv2d_0(hidden_states)
return hidden_states
......@@ -211,14 +217,17 @@ class Downsample2D(nn.Module):
else:
self.conv = conv
def forward(self, hidden_states):
def forward(self, hidden_states, scale: float = 1.0):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
if isinstance(self.conv, LoRACompatibleConv):
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
return hidden_states
......@@ -588,7 +597,7 @@ class ResnetBlock2D(nn.Module):
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
)
def forward(self, input_tensor, temb):
def forward(self, input_tensor, temb, scale: float = 1.0):
hidden_states = input_tensor
if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
......@@ -603,18 +612,34 @@ class ResnetBlock2D(nn.Module):
if hidden_states.shape[0] >= 64:
input_tensor = input_tensor.contiguous()
hidden_states = hidden_states.contiguous()
input_tensor = self.upsample(input_tensor)
hidden_states = self.upsample(hidden_states)
input_tensor = (
self.upsample(input_tensor, scale=scale)
if isinstance(self.upsample, Upsample2D)
else self.upsample(input_tensor)
)
hidden_states = (
self.upsample(hidden_states, scale=scale)
if isinstance(self.upsample, Upsample2D)
else self.upsample(hidden_states)
)
elif self.downsample is not None:
input_tensor = self.downsample(input_tensor)
hidden_states = self.downsample(hidden_states)
input_tensor = (
self.downsample(input_tensor, scale=scale)
if isinstance(self.downsample, Downsample2D)
else self.downsample(input_tensor)
)
hidden_states = (
self.downsample(hidden_states, scale=scale)
if isinstance(self.downsample, Downsample2D)
else self.downsample(hidden_states)
)
hidden_states = self.conv1(hidden_states)
hidden_states = self.conv1(hidden_states, scale)
if self.time_emb_proj is not None:
if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb)[:, :, None, None]
temb = self.time_emb_proj(temb, scale)[:, :, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
......@@ -631,10 +656,10 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.conv2(hidden_states, scale)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
input_tensor = self.conv_shortcut(input_tensor, scale)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
......
......@@ -274,6 +274,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 1. Input
if self.is_input_continuous:
batch, _, height, width = hidden_states.shape
......@@ -281,13 +284,14 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_in(hidden_states, lora_scale)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
hidden_states = self.proj_in(hidden_states, scale=lora_scale)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
......@@ -322,9 +326,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
hidden_states = self.proj_out(hidden_states, scale=lora_scale)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = self.proj_out(hidden_states, scale=lora_scale)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
......
......@@ -640,7 +640,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
hidden_states = self.resnets[0](hidden_states, temb)
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if self.training and self.gradient_checkpointing:
......@@ -677,7 +678,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
encoder_attention_mask=encoder_attention_mask,
return_dict=False,
)[0]
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states
......@@ -777,6 +778,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
......@@ -789,7 +791,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
# mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
mask = attention_mask
hidden_states = self.resnets[0](hidden_states, temb)
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# attn
hidden_states = attn(
......@@ -800,7 +802,7 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
)
# resnet
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
return hidden_states
......@@ -897,20 +899,25 @@ class AttnDownBlock2D(nn.Module):
else:
self.downsamplers = None
def forward(self, hidden_states, temb=None, upsample_size=None):
def forward(self, hidden_states, temb=None, upsample_size=None, cross_attention_kwargs=None):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
cross_attention_kwargs.update({"scale": lora_scale})
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(hidden_states, **cross_attention_kwargs)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
if self.downsample_type == "resnet":
hidden_states = downsampler(hidden_states, temb=temb)
hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale)
else:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale=lora_scale)
output_states += (hidden_states,)
......@@ -1019,6 +1026,8 @@ class CrossAttnDownBlock2D(nn.Module):
):
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
blocks = list(zip(self.resnets, self.attentions))
for i, (resnet, attn) in enumerate(blocks):
......@@ -1049,7 +1058,7 @@ class CrossAttnDownBlock2D(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
......@@ -1067,7 +1076,7 @@ class CrossAttnDownBlock2D(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale=lora_scale)
output_states = output_states + (hidden_states,)
......@@ -1126,7 +1135,7 @@ class DownBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states, temb=None, scale: float = 1.0):
output_states = ()
for resnet in self.resnets:
......@@ -1147,13 +1156,13 @@ class DownBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale=scale)
output_states = output_states + (hidden_states,)
......@@ -1209,13 +1218,13 @@ class DownEncoderBlock2D(nn.Module):
else:
self.downsamplers = None
def forward(self, hidden_states):
def forward(self, hidden_states, scale: float = 1.0):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=None)
hidden_states = resnet(hidden_states, temb=None, scale=scale)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale)
return hidden_states
......@@ -1292,14 +1301,15 @@ class AttnDownEncoderBlock2D(nn.Module):
else:
self.downsamplers = None
def forward(self, hidden_states):
def forward(self, hidden_states, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=None)
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, temb=None, scale=scale)
cross_attention_kwargs = {"scale": scale}
hidden_states = attn(hidden_states, **cross_attention_kwargs)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
hidden_states = downsampler(hidden_states, scale)
return hidden_states
......@@ -1385,16 +1395,17 @@ class AttnSkipDownBlock2D(nn.Module):
self.downsamplers = None
self.skip_conv = None
def forward(self, hidden_states, temb=None, skip_sample=None):
def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, temb, scale=scale)
cross_attention_kwargs = {"scale": scale}
hidden_states = attn(hidden_states, **cross_attention_kwargs)
output_states += (hidden_states,)
if self.downsamplers is not None:
hidden_states = self.resnet_down(hidden_states, temb)
hidden_states = self.resnet_down(hidden_states, temb, scale=scale)
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
......@@ -1465,15 +1476,15 @@ class SkipDownBlock2D(nn.Module):
self.downsamplers = None
self.skip_conv = None
def forward(self, hidden_states, temb=None, skip_sample=None):
def forward(self, hidden_states, temb=None, skip_sample=None, scale: float = 1.0):
output_states = ()
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale)
output_states += (hidden_states,)
if self.downsamplers is not None:
hidden_states = self.resnet_down(hidden_states, temb)
hidden_states = self.resnet_down(hidden_states, temb, scale)
for downsampler in self.downsamplers:
skip_sample = downsampler(skip_sample)
......@@ -1548,7 +1559,7 @@ class ResnetDownsampleBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states, temb=None, scale: float = 1.0):
output_states = ()
for resnet in self.resnets:
......@@ -1569,13 +1580,13 @@ class ResnetDownsampleBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale)
output_states = output_states + (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, temb)
hidden_states = downsampler(hidden_states, temb, scale)
output_states = output_states + (hidden_states,)
......@@ -1689,6 +1700,8 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
output_states = ()
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
mask = None if encoder_hidden_states is None else encoder_attention_mask
......@@ -1720,7 +1733,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
**cross_attention_kwargs,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
......@@ -1733,7 +1746,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, temb)
hidden_states = downsampler(hidden_states, temb, scale=lora_scale)
output_states = output_states + (hidden_states,)
......@@ -1786,7 +1799,7 @@ class KDownBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states, temb=None, scale: float = 1.0):
output_states = ()
for resnet in self.resnets:
......@@ -1807,7 +1820,7 @@ class KDownBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale)
output_states += (hidden_states,)
......@@ -1893,6 +1906,7 @@ class KCrossAttnDownBlock2D(nn.Module):
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
output_states = ()
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
......@@ -1922,7 +1936,7 @@ class KCrossAttnDownBlock2D(nn.Module):
encoder_attention_mask=encoder_attention_mask,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
......@@ -2033,22 +2047,23 @@ class AttnUpBlock2D(nn.Module):
else:
self.upsamplers = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
hidden_states = resnet(hidden_states, temb, scale=scale)
cross_attention_kwargs = {"scale": scale}
hidden_states = attn(hidden_states, **cross_attention_kwargs)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
if self.upsample_type == "resnet":
hidden_states = upsampler(hidden_states, temb=temb)
hidden_states = upsampler(hidden_states, temb=temb, scale=scale)
else:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states, scale=scale)
return hidden_states
......@@ -2150,6 +2165,8 @@ class CrossAttnUpBlock2D(nn.Module):
attention_mask: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
):
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
......@@ -2183,7 +2200,7 @@ class CrossAttnUpBlock2D(nn.Module):
return_dict=False,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
......@@ -2195,7 +2212,7 @@ class CrossAttnUpBlock2D(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
return hidden_states
......@@ -2248,7 +2265,7 @@ class UpBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
......@@ -2272,11 +2289,11 @@ class UpBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
return hidden_states
......@@ -2325,9 +2342,9 @@ class UpDecoderBlock2D(nn.Module):
else:
self.upsamplers = None
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states, temb=None, scale: float = 1.0):
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
......@@ -2404,14 +2421,15 @@ class AttnUpDecoderBlock2D(nn.Module):
else:
self.upsamplers = None
def forward(self, hidden_states, temb=None):
def forward(self, hidden_states, temb=None, scale: float = 1.0):
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb=temb)
hidden_states = attn(hidden_states, temb=temb)
hidden_states = resnet(hidden_states, temb=temb, scale=scale)
cross_attention_kwargs = {"scale": scale}
hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states)
hidden_states = upsampler(hidden_states, scale=scale)
return hidden_states
......@@ -2507,16 +2525,17 @@ class AttnSkipUpBlock2D(nn.Module):
self.skip_norm = None
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
hidden_states = self.attentions[0](hidden_states)
cross_attention_kwargs = {"scale": scale}
hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs)
if skip_sample is not None:
skip_sample = self.upsampler(skip_sample)
......@@ -2530,7 +2549,7 @@ class AttnSkipUpBlock2D(nn.Module):
skip_sample = skip_sample + skip_sample_states
hidden_states = self.resnet_up(hidden_states, temb)
hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
return hidden_states, skip_sample
......@@ -2604,14 +2623,14 @@ class SkipUpBlock2D(nn.Module):
self.skip_norm = None
self.act = None
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, skip_sample=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
if skip_sample is not None:
skip_sample = self.upsampler(skip_sample)
......@@ -2625,7 +2644,7 @@ class SkipUpBlock2D(nn.Module):
skip_sample = skip_sample + skip_sample_states
hidden_states = self.resnet_up(hidden_states, temb)
hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
return hidden_states, skip_sample
......@@ -2697,7 +2716,7 @@ class ResnetUpsampleBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
......@@ -2721,11 +2740,11 @@ class ResnetUpsampleBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, temb)
hidden_states = upsampler(hidden_states, temb, scale=scale)
return hidden_states
......@@ -2840,6 +2859,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
):
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
lora_scale = cross_attention_kwargs.get("scale", 1.0)
if attention_mask is None:
# if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
mask = None if encoder_hidden_states is None else encoder_attention_mask
......@@ -2877,7 +2897,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
**cross_attention_kwargs,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
......@@ -2888,7 +2908,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, temb)
hidden_states = upsampler(hidden_states, temb, scale=lora_scale)
return hidden_states
......@@ -2941,7 +2961,7 @@ class KUpBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale: float = 1.0):
res_hidden_states_tuple = res_hidden_states_tuple[-1]
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
......@@ -2964,7 +2984,7 @@ class KUpBlock2D(nn.Module):
create_custom_forward(resnet), hidden_states, temb
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=scale)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
......@@ -3072,6 +3092,7 @@ class KCrossAttnUpBlock2D(nn.Module):
if res_hidden_states_tuple is not None:
hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
......@@ -3100,7 +3121,7 @@ class KCrossAttnUpBlock2D(nn.Module):
encoder_attention_mask=encoder_attention_mask,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
......
......@@ -934,6 +934,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
# 3. down
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
......@@ -956,7 +957,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
**additional_residuals,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
if is_adapter and len(down_block_additional_residuals) > 0:
sample += down_block_additional_residuals.pop(0)
......@@ -1020,7 +1021,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
)
else:
sample = upsample_block(
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
scale=lora_scale,
)
# 6. post-process
......
......@@ -25,6 +25,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
......@@ -322,6 +323,9 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
......@@ -27,6 +27,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
......@@ -320,6 +321,9 @@ class AltDiffusionImg2ImgPipeline(
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
......@@ -25,6 +25,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
......@@ -312,6 +313,9 @@ class StableDiffusionControlNetPipeline(
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
......@@ -25,6 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
......@@ -337,6 +337,9 @@ class StableDiffusionControlNetImg2ImgPipeline(
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
......@@ -26,6 +26,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
......@@ -463,6 +464,9 @@ class StableDiffusionControlNetInpaintPipeline(
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
......@@ -33,6 +33,7 @@ from ...models.attention_processor import (
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_accelerate_available,
......@@ -342,6 +343,10 @@ class StableDiffusionXLControlNetInpaintPipeline(DiffusionPipeline, LoraLoaderMi
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
......@@ -34,6 +34,7 @@ from ...models.attention_processor import (
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_accelerate_available,
......@@ -315,6 +316,10 @@ class StableDiffusionXLControlNetPipeline(
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
......@@ -33,6 +33,7 @@ from ...models.attention_processor import (
LoRAXFormersAttnProcessor,
XFormersAttnProcessor,
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
is_accelerate_available,
......@@ -352,6 +353,10 @@ class StableDiffusionXLControlNetImg2ImgPipeline(DiffusionPipeline, TextualInver
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
......@@ -27,6 +27,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
......@@ -329,6 +330,9 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
......@@ -23,6 +23,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
deprecate,
......@@ -321,6 +322,9 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
......@@ -25,6 +25,7 @@ from ...image_processor import VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
......@@ -321,6 +322,9 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......
......@@ -26,6 +26,7 @@ from ...configuration_utils import FrozenDict
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -203,6 +204,9 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
......