"torchvision/vscode:/vscode.git/clone" did not exist on "249cfbf5fef41ee30f380d895ad451b0d16d2947"
Unverified commit 2bfa55f4 authored by Younes Belkada, committed by GitHub

[`core` / `PEFT` / `LoRA`] Integrate PEFT into Unet (#5151)



* v1

* add tests and fix previous failing tests

* fix CI

* add tests + v1 `PeftLayerScaler`

* style

* add scale retrieving mechanism system

* fix CI

* up

* up

* simple approach --> not same results for some reason

* fix issues

* fix copies

* remove unneeded method

* active adapters!

* fix merge conflicts

* up

* up

* kohya - test-1

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix scale

* fix copies

* add comment

* multi adapters

* fix tests

* oops

* v1 faster loading - in progress

* Revert "v1 faster loading - in progress"

This reverts commit ac925f81321e95fc8168184c3346bf3d75404d5a.

* kohya same generation

* fix some slow tests

* peft integration features for unet lora

1. Support for Multiple ranks/alphas
2. Support for Multiple active adapters
3. Support for enabling/disabling LoRAs

* fix `get_peft_kwargs`

* Update loaders.py

* add some tests

* add unfuse tests

* fix tests

* up

* add set adapter from sourab and tests

* fix multi adapter tests

* style & quality

* style

* remove comment

* fix `adapter_name` issues

* fix unet adapter name for sdxl

* fix enabling/disabling adapters

* fix fuse / unfuse unet

* nit

* fix

* up

* fix cpu offloading

* fix another slow test

* fix another offload test

* add more tests

* all slow tests pass

* style

* fix alpha pattern for unet and text encoder

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/attention.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* up

* up

* clarify comment

* comments

* change comment order

* change comment order

* stylr & quality

* Update tests/lora/test_lora_layers_peft.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix bugs and add tests

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* refactor

* suggestion

* add break statement

* add compile tests

* move slow tests to peft tests as I modified them

* quality

* refactor a bit

* style

* change import

* style

* fix CI

* refactor slow tests one last time

* style

* oops

* oops

* oops

* final tweak tests

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* comments

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* remove comments

* more comments

* try

* revert

* add `safe_merge` tests

* add comment

* style, comments and run tests in fp16

* add warnings

* fix doc test

* replace with `adapter_weights`

* add `get_active_adapters()`

* expose `get_list_adapters` method

* better error message

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* style

* trigger slow lora tests

* fix tests

* maybe fix last test

* revert

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* move `MIN_PEFT_VERSION`

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* let's not use class variable

* fix few nits

* change a bit offloading logic

* check earlier

* rm unneeded block

* break long line

* return empty list

* change logic a bit and address comments

* add typehint

* remove parenthesis

* fix

* revert to fp16 in tests

* add to gpu

* revert to old test

* style

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* change indent

* Apply suggestions from code review

* Apply suggestions from code review

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent 9bc55e8b
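
Before the diff, a minimal usage sketch of the workflow these commits describe (multiple ranks/alphas, multiple active adapters, enabling/disabling and fusing LoRAs). It assumes diffusers built from this branch with a recent `peft` installed; the LoRA repository ids and adapter names are placeholders, and the helper names follow the commit notes above.

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load two LoRAs under distinct adapter names; each may use its own rank/alpha.
pipe.load_lora_weights("some-user/pixel-lora", adapter_name="pixel")  # placeholder repo id
pipe.load_lora_weights("some-user/toy-lora", adapter_name="toy")      # placeholder repo id

# Activate both adapters with per-adapter weights.
pipe.set_adapters(["pixel", "toy"], adapter_weights=[0.7, 0.5])
print(pipe.get_active_adapters())  # e.g. ["pixel", "toy"]
print(pipe.get_list_adapters())    # adapters known per component, e.g. {"unet": [...], ...}

image = pipe("a toy robot, pixel art", num_inference_steps=30).images[0]

# Temporarily fall back to the base model, then re-enable the adapters.
pipe.disable_lora()
base_image = pipe("a toy robot, pixel art", num_inference_steps=30).images[0]
pipe.enable_lora()

# Optionally merge the active LoRA weights into the model for faster inference.
pipe.fuse_lora()
# ... run inference ...
pipe.unfuse_lora()
```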
(One file's diff is collapsed and not shown here.)
......@@ -17,6 +17,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import USE_PEFT_BACKEND
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import get_activation
from .attention_processor import Attention
......@@ -300,6 +301,7 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
......@@ -316,14 +318,15 @@ class FeedForward(nn.Module):
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
self.net.append(linear_cls(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
for module in self.net:
if isinstance(module, (LoRACompatibleLinear, GEGLU)):
if isinstance(module, compatible_cls):
hidden_states = module(hidden_states, scale)
else:
hidden_states = module(hidden_states)
......@@ -368,7 +371,9 @@ class GEGLU(nn.Module):
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
self.proj = linear_cls(dim_in, dim_out * 2)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
......@@ -377,7 +382,8 @@ class GEGLU(nn.Module):
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states, scale: float = 1.0):
hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
args = () if USE_PEFT_BACKEND else (scale,)
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
......
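
The hunks above all apply the same dispatch pattern: under `USE_PEFT_BACKEND`, plain `nn.Linear` layers are created (PEFT injects LoRA into them later), so the extra `scale` positional argument is dropped from the forward calls. Below is an illustrative standalone sketch of that pattern, not the library code; `LoRACompatibleLinearStub` is a simplified stand-in for `diffusers.models.lora.LoRACompatibleLinear`.

```python
import torch
from torch import nn

USE_PEFT_BACKEND = True  # in diffusers this flag lives in diffusers.utils


class LoRACompatibleLinearStub(nn.Linear):
    """Simplified stand-in: the real class adds scale * lora(hidden_states) to the output."""

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return super().forward(hidden_states)


class TinyFeedForward(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        # Same dispatch as in FeedForward / GEGLU / Attention above.
        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinearStub
        self.proj = linear_cls(dim, dim)

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # Drop the extra positional arg when the layer is a vanilla nn.Linear.
        args = () if USE_PEFT_BACKEND else (scale,)
        return self.proj(hidden_states, *args)


print(TinyFeedForward(8)(torch.randn(2, 8), scale=0.5).shape)  # torch.Size([2, 8])
```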
......@@ -18,7 +18,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate, logging
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .lora import LoRACompatibleLinear, LoRALinearLayer
......@@ -137,22 +137,27 @@ class Attention(nn.Module):
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)
self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)
if USE_PEFT_BACKEND:
linear_cls = nn.Linear
else:
linear_cls = LoRACompatibleLinear
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
else:
self.to_k = None
self.to_v = None
if self.added_kv_proj_dim is not None:
self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
# set attention processor
......@@ -545,6 +550,8 @@ class AttnProcessor:
):
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
......@@ -562,15 +569,15 @@ class AttnProcessor:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, scale=scale)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
......@@ -581,7 +588,7 @@ class AttnProcessor:
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, scale=scale)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
......@@ -1007,15 +1014,20 @@ class AttnProcessor2_0:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, scale=scale)
args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
key = (
attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
)
value = (
attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
......@@ -1035,7 +1047,9 @@ class AttnProcessor2_0:
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, scale=scale)
hidden_states = (
attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
)
# dropout
hidden_states = attn.to_out[1](hidden_states)
......
......@@ -18,6 +18,7 @@ import numpy as np
import torch
from torch import nn
from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .lora import LoRACompatibleLinear
......@@ -166,8 +167,9 @@ class TimestepEmbedding(nn.Module):
cond_proj_dim=None,
):
super().__init__()
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
self.linear_1 = linear_cls(in_channels, time_embed_dim)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
......@@ -180,7 +182,7 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)
if post_act_fn is None:
self.post_act = None
......
......@@ -32,10 +32,12 @@ from ..utils import (
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
MIN_PEFT_VERSION,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
_add_variant,
_get_model_file,
check_peft_version,
deprecate,
is_accelerate_available,
is_torch_version,
......@@ -187,6 +189,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_supports_gradient_checkpointing = False
_keys_to_ignore_on_load_unexpected = None
_hf_peft_config_loaded = False
def __init__(self):
super().__init__()
......@@ -292,6 +295,153 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"""
self.set_use_memory_efficient_attention_xformers(False)
def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
r"""
Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
to the adapter to follow the convention of the PEFT library.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
[documentation](https://huggingface.co/docs/peft).
Args:
adapter_config (`[~peft.PeftConfig]`):
The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
methods.
adapter_name (`str`, *optional*, defaults to `"default"`):
The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
from peft import PeftConfig, inject_adapter_in_model
if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
elif adapter_name in self.peft_config:
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
if not isinstance(adapter_config, PeftConfig):
raise ValueError(
f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
)
# Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
# handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
adapter_config.base_model_name_or_path = None
inject_adapter_in_model(adapter_config, self, adapter_name)
self.set_adapter(adapter_name)
def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
"""
Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
Args:
adapter_name (Union[str, List[str]])):
The list of adapters to set or the adapter name in case of single adapter.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
if isinstance(adapter_name, str):
adapter_name = [adapter_name]
missing = set(adapter_name) - set(self.peft_config)
if len(missing) > 0:
raise ValueError(
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
f" current loaded adapters are: {list(self.peft_config.keys())}"
)
from peft.tuners.tuners_utils import BaseTunerLayer
_adapters_has_been_set = False
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
# Previous versions of PEFT does not support multi-adapter inference
elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
raise ValueError(
"You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
" `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
)
else:
module.active_adapter = adapter_name
_adapters_has_been_set = True
if not _adapters_has_been_set:
raise ValueError(
"Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
)
def disable_adapters(self) -> None:
r"""
Disable all adapters attached to the model and fallback to inference with the base model only.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft.tuners.tuners_utils import BaseTunerLayer
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
# support for older PEFT versions
module.disable_adapters = True
def enable_adapters(self) -> None:
"""
Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
list of adapters to enable.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft.tuners.tuners_utils import BaseTunerLayer
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
# support for older PEFT versions
module.disable_adapters = False
def active_adapters(self) -> List[str]:
"""
Gets the current list of active adapters of the model.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft.tuners.tuners_utils import BaseTunerLayer
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
return module.active_adapter
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
......
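
A minimal sketch of the adapter API added to `ModelMixin` above, assuming a compatible `peft` version is installed. The checkpoint id, adapter name, and `target_modules` list are illustrative.

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # placeholder checkpoint
)

config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # attention projections shown in the diff
)

# Inject and activate a LoRA adapter for training.
unet.add_adapter(config, adapter_name="my_lora")
print(unet.active_adapters())  # the currently active adapter(s)

# Toggle the adapter off/on around inference with the base weights.
unet.disable_adapters()
unet.enable_adapters()
unet.set_adapter("my_lora")
```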
......@@ -20,6 +20,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm
......@@ -149,12 +150,13 @@ class Upsample2D(nn.Module):
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
conv = None
if use_conv_transpose:
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
......@@ -193,12 +195,12 @@ class Upsample2D(nn.Module):
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
if isinstance(self.conv, LoRACompatibleConv):
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
if isinstance(self.Conv2d_0, LoRACompatibleConv):
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.Conv2d_0(hidden_states, scale)
else:
hidden_states = self.Conv2d_0(hidden_states)
......@@ -237,9 +239,10 @@ class Downsample2D(nn.Module):
self.padding = padding
stride = 2
self.name = name
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if use_conv:
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
......@@ -255,15 +258,20 @@ class Downsample2D(nn.Module):
def forward(self, hidden_states, scale: float = 1.0):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert hidden_states.shape[1] == self.channels
if not USE_PEFT_BACKEND:
if isinstance(self.conv, LoRACompatibleConv):
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.conv(hidden_states)
return hidden_states
......@@ -608,6 +616,9 @@ class ResnetBlock2D(nn.Module):
self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if groups_out is None:
groups_out = groups
......@@ -618,13 +629,13 @@ class ResnetBlock2D(nn.Module):
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
self.time_emb_proj = linear_cls(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
self.time_emb_proj = None
else:
......@@ -641,7 +652,7 @@ class ResnetBlock2D(nn.Module):
self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.nonlinearity = get_activation(non_linearity)
......@@ -667,7 +678,7 @@ class ResnetBlock2D(nn.Module):
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = LoRACompatibleConv(
self.conv_shortcut = conv_cls(
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
)
......@@ -708,12 +719,16 @@ class ResnetBlock2D(nn.Module):
else self.downsample(hidden_states)
)
hidden_states = self.conv1(hidden_states, scale)
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
if self.time_emb_proj is not None:
if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb, scale)[:, :, None, None]
temb = (
self.time_emb_proj(temb, scale)[:, :, None, None]
if not USE_PEFT_BACKEND
else self.time_emb_proj(temb)[:, :, None, None]
)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
......@@ -730,10 +745,12 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, scale)
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor, scale)
input_tensor = (
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
......
......@@ -20,7 +20,7 @@ from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput, deprecate
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
from .lora import LoRACompatibleConv, LoRACompatibleLinear
......@@ -100,6 +100,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
......@@ -139,9 +142,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
self.proj_in = linear_cls(in_channels, inner_dim)
else:
self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
......@@ -197,9 +200,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
self.proj_out = linear_cls(inner_dim, in_channels)
else:
self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
......@@ -292,13 +295,21 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states, scale=lora_scale)
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states, scale=lora_scale)
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
......@@ -334,9 +345,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states, scale=lora_scale)
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
else:
hidden_states = self.proj_out(hidden_states, scale=lora_scale)
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
......
......@@ -20,7 +20,7 @@ import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from .activations import get_activation
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
......@@ -995,6 +995,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# 3. down
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
......@@ -1094,6 +1097,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self)
if not return_dict:
return (sample,)
......
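
A standalone sketch of what the hunk above does inside `UNet2DConditionModel.forward`, with the `scale` taken from `cross_attention_kwargs`. The wrapper function is hypothetical; `scale_lora_layers`/`unscale_lora_layers` are the helpers imported in the diff, and the UNet is assumed to have a PEFT LoRA adapter attached.

```python
from diffusers.utils import USE_PEFT_BACKEND, scale_lora_layers, unscale_lora_layers


def run_unet_with_lora_scale(unet, sample, timestep, encoder_hidden_states, lora_scale=0.8):
    # Hypothetical helper mirroring the forward() logic above.
    if USE_PEFT_BACKEND:
        # Multiply every PEFT LoRA layer's scaling by `lora_scale` for this call.
        scale_lora_layers(unet, lora_scale)
    out = unet(sample, timestep, encoder_hidden_states=encoder_hidden_states).sample
    if USE_PEFT_BACKEND:
        # Restore the original scaling so later calls are unaffected.
        unscale_lora_layers(unet)
    return out
```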
......@@ -25,7 +25,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......@@ -304,7 +311,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -432,7 +439,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......@@ -668,6 +675,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# 1. Check inputs. Raise error if not correct
self.check_inputs(
......@@ -689,9 +697,8 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
......@@ -700,7 +707,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
lora_scale=lora_scale,
clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
......
......@@ -29,6 +29,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
......@@ -309,7 +310,7 @@ class AltDiffusionImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -437,7 +438,7 @@ class AltDiffusionImg2ImgPipeline(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -27,7 +27,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
......@@ -287,7 +294,7 @@ class StableDiffusionControlNetPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -415,7 +422,7 @@ class StableDiffusionControlNetPipeline(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -27,6 +27,7 @@ from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
......@@ -317,7 +318,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -445,7 +446,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -28,7 +28,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
......@@ -438,7 +445,7 @@ class StableDiffusionControlNetInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -566,7 +573,7 @@ class StableDiffusionControlNetInpaintPipeline(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -33,6 +33,7 @@ from ...models.attention_processor import (
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
is_invisible_watermark_available,
logging,
replace_example_docstring,
......@@ -316,7 +317,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
......@@ -458,7 +459,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
bs_embed * num_images_per_prompt, -1
)
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
......
......@@ -35,7 +35,7 @@ from ...models.attention_processor import (
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
......@@ -285,7 +285,7 @@ class StableDiffusionXLControlNetPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
......@@ -427,7 +427,7 @@ class StableDiffusionXLControlNetPipeline(
bs_embed * num_images_per_prompt, -1
)
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
......
......@@ -36,6 +36,7 @@ from ...models.attention_processor import (
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
logging,
replace_example_docstring,
scale_lora_layers,
......@@ -328,7 +329,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
......@@ -470,7 +471,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
bs_embed * num_images_per_prompt, -1
)
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
......
......@@ -27,7 +27,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import StableDiffusionPipelineOutput
......@@ -308,7 +308,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -436,7 +436,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -25,7 +25,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import StableDiffusionPipelineOutput
......@@ -297,7 +304,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -425,7 +432,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......@@ -658,6 +665,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# 1. Check inputs. Raise error if not correct
self.check_inputs(
......@@ -679,9 +687,8 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
......@@ -690,7 +697,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
lora_scale=lora_scale,
clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
......
......@@ -27,7 +27,14 @@ from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
......@@ -332,7 +339,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -460,7 +467,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -28,7 +28,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -213,7 +213,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -341,7 +341,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
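
At the pipeline level, the net effect of the hunks above is that the `"scale"` entry in `cross_attention_kwargs` drives both the text encoder LoRA (via `encode_prompt`'s `lora_scale`) and the UNet's PEFT layers. A hedged end-to-end sketch with placeholder repository ids:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-style-lora", adapter_name="style")  # placeholder repo id

# Apply the LoRA at half strength for this generation only.
image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=25,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```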