"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "291ecdacd3a252692e889d99e57be531f57ced12"
Unverified Commit 2bfa55f4 authored by Younes Belkada, committed by GitHub

[`core` / `PEFT` / `LoRA`] Integrate PEFT into Unet (#5151)



* v1

* add tests and fix previous failing tests

* fix CI

* add tests + v1 `PeftLayerScaler`

* style

* add scale retrieving mechanism system

* fix CI

* up

* up

* simple approach --> not same results for some reason

* fix issues

* fix copies

* remove unneeded method

* active adapters!

* fix merge conflicts

* up

* up

* kohya - test-1

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix scale

* fix copies

* add comment

* multi adapters

* fix tests

* oops

* v1 faster loading - in progress

* Revert "v1 faster loading - in progress"

This reverts commit ac925f81321e95fc8168184c3346bf3d75404d5a.

* kohya same generation

* fix some slow tests

* peft integration features for unet lora

1. Support for Multiple ranks/alphas
2. Support for Multiple active adapters
3. Support for enabling/disabling LoRAs
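
A hedged sketch of how these three features surface through the pipeline-level LoRA helpers that accompany this integration (assuming `load_lora_weights(..., adapter_name=...)`, `set_adapters(..., adapter_weights=...)`, `get_active_adapters()` and `enable_lora()`/`disable_lora()`; adapter names and checkpoint paths below are placeholders, and exact helper names can vary across diffusers versions):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# 1. multiple ranks/alphas: each checkpoint may carry its own rank/alpha, resolved per adapter
pipe.load_lora_weights("path/to/lora_a", adapter_name="style_a")  # placeholder path/name
pipe.load_lora_weights("path/to/lora_b", adapter_name="style_b")  # placeholder path/name

# 2. multiple active adapters, each with its own weight
pipe.set_adapters(["style_a", "style_b"], adapter_weights=[1.0, 0.5])
print(pipe.get_active_adapters())  # e.g. ["style_a", "style_b"]

# 3. enabling/disabling LoRAs without unloading them
pipe.disable_lora()  # fall back to the base model
pipe.enable_lora()   # re-activate the adapters selected above
```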

* fix `get_peft_kwargs`

* Update loaders.py

* add some tests

* add unfuse tests

* fix tests

* up

* add set adapter from sourab and tests

* fix multi adapter tests

* style & quality

* style

* remove comment

* fix `adapter_name` issues

* fix unet adapter name for sdxl

* fix enabling/disabling adapters

* fix fuse / unfuse unet

* nit

* fix

* up

* fix cpu offloading

* fix another slow test

* fix another offload test

* add more tests

* all slow tests pass

* style

* fix alpha pattern for unet and text encoder

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/attention.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* up

* up

* clarify comment

* comments

* change comment order

* change comment order

* style & quality

* Update tests/lora/test_lora_layers_peft.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix bugs and add tests

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* refactor

* suggestion

* add break statement

* add compile tests

* move slow tests to peft tests as I modified them

* quality

* refactor a bit

* style

* change import

* style

* fix CI

* refactor slow tests one last time

* style

* oops

* oops

* oops

* final tweak tests

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* comments

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* remove comments

* more comments

* try

* revert

* add `safe_merge` tests

* add comment

* style, comments and run tests in fp16

* add warnings

* fix doc test

* replace with `adapter_weights`

* add `get_active_adapters()`

* expose `get_list_adapters` method

* better error message

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* style

* trigger slow lora tests

* fix tests

* maybe fix last test

* revert

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* move `MIN_PEFT_VERSION`

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* let's not use class variable

* fix few nits

* change a bit offloading logic

* check earlier

* rm unneeded block

* break long line

* return empty list

* change logic a bit and address comments

* add typehint

* remove parenthesis

* fix

* revert to fp16 in tests

* add to gpu

* revert to old test

* style

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* change indent

* Apply suggestions from code review

* Apply suggestions from code review

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent 9bc55e8b
@@ -17,6 +17,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
+from ..utils import USE_PEFT_BACKEND
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import get_activation
 from .attention_processor import Attention
@@ -300,6 +301,7 @@ class FeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
+        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim)
@@ -316,14 +318,15 @@ class FeedForward(nn.Module):
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
+        self.net.append(linear_cls(inner_dim, dim_out))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))
     def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
         for module in self.net:
-            if isinstance(module, (LoRACompatibleLinear, GEGLU)):
+            if isinstance(module, compatible_cls):
                 hidden_states = module(hidden_states, scale)
             else:
                 hidden_states = module(hidden_states)
@@ -368,7 +371,9 @@ class GEGLU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
-        self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
+        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+        self.proj = linear_cls(dim_in, dim_out * 2)
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
@@ -377,7 +382,8 @@ class GEGLU(nn.Module):
         return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
     def forward(self, hidden_states, scale: float = 1.0):
-        hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
+        args = () if USE_PEFT_BACKEND else (scale,)
+        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
         return hidden_states * self.gelu(gate)
......
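
The changes above follow a single dispatch pattern that recurs throughout this diff: with the PEFT backend active, plain `nn.Linear`/`nn.Conv2d` modules are instantiated and no `scale` argument is forwarded (PEFT's own layers handle LoRA scaling), otherwise the diffusers `LoRACompatible*` layers receive the scale explicitly. A minimal, self-contained sketch of the idea (the `DummyLoRACompatibleLinear` class and the hard-coded `USE_PEFT_BACKEND` flag are stand-ins, not diffusers code):

```python
import torch
import torch.nn as nn

USE_PEFT_BACKEND = True  # stand-in: in diffusers this is derived from the installed peft/transformers versions


class DummyLoRACompatibleLinear(nn.Linear):
    """Stand-in for diffusers.models.lora.LoRACompatibleLinear: accepts an extra `scale` argument."""

    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # the real class would add `scale * lora_layer(x)` on top of the base projection
        return super().forward(x)


# class switch: plain torch layer when PEFT manages LoRA, LoRA-aware layer otherwise
linear_cls = nn.Linear if USE_PEFT_BACKEND else DummyLoRACompatibleLinear
proj = linear_cls(8, 8)

# call-convention switch: only the LoRA-aware layer gets the explicit scale argument
x = torch.randn(2, 8)
args = () if USE_PEFT_BACKEND else (1.0,)
y = proj(x, *args)
print(y.shape)  # torch.Size([2, 8])
```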
@@ -18,7 +18,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn
-from ..utils import deprecate, logging
+from ..utils import USE_PEFT_BACKEND, deprecate, logging
 from ..utils.import_utils import is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .lora import LoRACompatibleLinear, LoRALinearLayer
@@ -137,22 +137,27 @@ class Attention(nn.Module):
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )
-        self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)
+        if USE_PEFT_BACKEND:
+            linear_cls = nn.Linear
+        else:
+            linear_cls = LoRACompatibleLinear
+
+        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
         if not self.only_cross_attention:
             # only relevant for the `AddedKVProcessor` classes
-            self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
-            self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
         else:
             self.to_k = None
             self.to_v = None
         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
-            self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
+            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
         self.to_out = nn.ModuleList([])
-        self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
+        self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))
         # set attention processor
@@ -545,6 +550,8 @@ class AttnProcessor:
     ):
         residual = hidden_states
+        args = () if USE_PEFT_BACKEND else (scale,)
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -562,15 +569,15 @@ class AttnProcessor:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-        key = attn.to_k(encoder_hidden_states, scale=scale)
-        value = attn.to_v(encoder_hidden_states, scale=scale)
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)
         query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
@@ -581,7 +588,7 @@ class AttnProcessor:
         hidden_states = attn.batch_to_head_dim(hidden_states)
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -1007,15 +1014,20 @@ class AttnProcessor2_0:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-        query = attn.to_q(hidden_states, scale=scale)
+        args = () if USE_PEFT_BACKEND else (scale,)
+        query = attn.to_q(hidden_states, *args)
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-        key = attn.to_k(encoder_hidden_states, scale=scale)
-        value = attn.to_v(encoder_hidden_states, scale=scale)
+        key = (
+            attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
+        )
+        value = (
+            attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
+        )
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1035,7 +1047,9 @@ class AttnProcessor2_0:
         hidden_states = hidden_states.to(query.dtype)
         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = (
+            attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
+        )
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
......
@@ -18,6 +18,7 @@ import numpy as np
 import torch
 from torch import nn
+from ..utils import USE_PEFT_BACKEND
 from .activations import get_activation
 from .lora import LoRACompatibleLinear
@@ -166,8 +167,9 @@ class TimestepEmbedding(nn.Module):
         cond_proj_dim=None,
     ):
         super().__init__()
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
-        self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
+        self.linear_1 = linear_cls(in_channels, time_embed_dim)
         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -180,7 +182,7 @@
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
+        self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)
         if post_act_fn is None:
             self.post_act = None
......
@@ -32,10 +32,12 @@ from ..utils import (
     DIFFUSERS_CACHE,
     FLAX_WEIGHTS_NAME,
     HF_HUB_OFFLINE,
+    MIN_PEFT_VERSION,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     _add_variant,
     _get_model_file,
+    check_peft_version,
     deprecate,
     is_accelerate_available,
     is_torch_version,
@@ -187,6 +189,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
+    _hf_peft_config_loaded = False
     def __init__(self):
         super().__init__()
@@ -292,6 +295,153 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         self.set_use_memory_efficient_attention_xformers(False)
+    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
+        r"""
+        Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
+        to the adapter to follow the convention of the PEFT library.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
+        [documentation](https://huggingface.co/docs/peft).
+
+        Args:
+            adapter_config (`[~peft.PeftConfig]`):
+                The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
+                methods.
+            adapter_name (`str`, *optional*, defaults to `"default"`):
+                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        from peft import PeftConfig, inject_adapter_in_model
+
+        if not self._hf_peft_config_loaded:
+            self._hf_peft_config_loaded = True
+        elif adapter_name in self.peft_config:
+            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
+
+        if not isinstance(adapter_config, PeftConfig):
+            raise ValueError(
+                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
+            )
+
+        # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
+        # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
+        adapter_config.base_model_name_or_path = None
+        inject_adapter_in_model(adapter_config, self, adapter_name)
+        self.set_adapter(adapter_name)
+
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+
+        Args:
+            adapter_name (Union[str, List[str]])):
+                The list of adapters to set or the adapter name in case of single adapter.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        if isinstance(adapter_name, str):
+            adapter_name = [adapter_name]
+
+        missing = set(adapter_name) - set(self.peft_config)
+        if len(missing) > 0:
+            raise ValueError(
+                f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
+                f" current loaded adapters are: {list(self.peft_config.keys())}"
+            )
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        _adapters_has_been_set = False
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "set_adapter"):
+                    module.set_adapter(adapter_name)
+                # Previous versions of PEFT does not support multi-adapter inference
+                elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
+                    raise ValueError(
+                        "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
+                        " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
+                    )
+                else:
+                    module.active_adapter = adapter_name
+                _adapters_has_been_set = True
+
+        if not _adapters_has_been_set:
+            raise ValueError(
+                "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
+            )
+
+    def disable_adapters(self) -> None:
+        r"""
+        Disable all adapters attached to the model and fallback to inference with the base model only.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=False)
+                else:
+                    # support for older PEFT versions
+                    module.disable_adapters = True
+
+    def enable_adapters(self) -> None:
+        """
+        Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
+        list of adapters to enable.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=True)
+                else:
+                    # support for older PEFT versions
+                    module.disable_adapters = False
+
+    def active_adapters(self) -> List[str]:
+        """
+        Gets the current list of active adapters of the model.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                return module.active_adapter
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
......
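
The new `ModelMixin` methods above give every diffusers model a small adapter-management API on top of PEFT. A hedged sketch of how they might be exercised on a UNet (the checkpoint id and `target_modules` list are illustrative choices, not mandated by the diff):

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig  # LoraConfig is a PeftConfig subclass, as required by add_adapter()

# illustrative checkpoint; any UNet2DConditionModel works
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# typical attention projections in the diffusers UNet; adjust to taste
lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)

unet.add_adapter(lora_config, adapter_name="my_lora")  # injects PEFT LoRA layers in place
unet.set_adapter("my_lora")                            # make it the (only) active adapter

print(unet.active_adapters())  # adapter name(s) reported by the underlying PEFT layers

unet.disable_adapters()  # run the base model only
unet.enable_adapters()   # re-enable the adapters selected above
```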
@@ -20,6 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from ..utils import USE_PEFT_BACKEND
 from .activations import get_activation
 from .attention import AdaGroupNorm
 from .attention_processor import SpatialNorm
@@ -149,12 +150,13 @@ class Upsample2D(nn.Module):
         self.use_conv = use_conv
         self.use_conv_transpose = use_conv_transpose
         self.name = name
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
         conv = None
         if use_conv_transpose:
             conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
         elif use_conv:
-            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
+            conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if name == "conv":
@@ -193,12 +195,12 @@
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
             if self.name == "conv":
-                if isinstance(self.conv, LoRACompatibleConv):
+                if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
                     hidden_states = self.conv(hidden_states, scale)
                 else:
                     hidden_states = self.conv(hidden_states)
             else:
-                if isinstance(self.Conv2d_0, LoRACompatibleConv):
+                if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
                     hidden_states = self.Conv2d_0(hidden_states, scale)
                 else:
                     hidden_states = self.Conv2d_0(hidden_states)
@@ -237,9 +239,10 @@ class Downsample2D(nn.Module):
         self.padding = padding
         stride = 2
         self.name = name
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
         if use_conv:
-            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+            conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
         else:
             assert self.channels == self.out_channels
             conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
@@ -255,13 +258,18 @@
     def forward(self, hidden_states, scale: float = 1.0):
         assert hidden_states.shape[1] == self.channels
         if self.use_conv and self.padding == 0:
             pad = (0, 1, 0, 1)
             hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
         assert hidden_states.shape[1] == self.channels
-        if isinstance(self.conv, LoRACompatibleConv):
-            hidden_states = self.conv(hidden_states, scale)
+        if not USE_PEFT_BACKEND:
+            if isinstance(self.conv, LoRACompatibleConv):
+                hidden_states = self.conv(hidden_states, scale)
+            else:
+                hidden_states = self.conv(hidden_states)
         else:
             hidden_states = self.conv(hidden_states)
@@ -608,6 +616,9 @@ class ResnetBlock2D(nn.Module):
         self.time_embedding_norm = time_embedding_norm
         self.skip_time_act = skip_time_act
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+
         if groups_out is None:
             groups_out = groups
@@ -618,13 +629,13 @@
         else:
             self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
-        self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
         if temb_channels is not None:
             if self.time_embedding_norm == "default":
-                self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
+                self.time_emb_proj = linear_cls(temb_channels, out_channels)
             elif self.time_embedding_norm == "scale_shift":
-                self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
+                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
             elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                 self.time_emb_proj = None
             else:
@@ -641,7 +652,7 @@
         self.dropout = torch.nn.Dropout(dropout)
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
         self.nonlinearity = get_activation(non_linearity)
@@ -667,7 +678,7 @@
         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = LoRACompatibleConv(
+            self.conv_shortcut = conv_cls(
                 in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
             )
@@ -708,12 +719,16 @@
                 else self.downsample(hidden_states)
             )
-        hidden_states = self.conv1(hidden_states, scale)
+        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
         if self.time_emb_proj is not None:
             if not self.skip_time_act:
                 temb = self.nonlinearity(temb)
-            temb = self.time_emb_proj(temb, scale)[:, :, None, None]
+            temb = (
+                self.time_emb_proj(temb, scale)[:, :, None, None]
+                if not USE_PEFT_BACKEND
+                else self.time_emb_proj(temb)[:, :, None, None]
+            )
         if temb is not None and self.time_embedding_norm == "default":
             hidden_states = hidden_states + temb
@@ -730,10 +745,12 @@
         hidden_states = self.nonlinearity(hidden_states)
         hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states, scale)
+        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
         if self.conv_shortcut is not None:
-            input_tensor = self.conv_shortcut(input_tensor, scale)
+            input_tensor = (
+                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
+            )
         output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
......
@@ -20,7 +20,7 @@ from torch import nn
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import BaseOutput, deprecate
+from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
 from .attention import BasicTransformerBlock
 from .embeddings import PatchEmbed
 from .lora import LoRACompatibleConv, LoRACompatibleLinear
@@ -100,6 +100,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         self.attention_head_dim = attention_head_dim
         inner_dim = num_attention_heads * attention_head_dim
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
         # Define whether input is continuous or discrete depending on configuration
         self.is_input_continuous = (in_channels is not None) and (patch_size is None)
@@ -139,9 +142,9 @@
             self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
             if use_linear_projection:
-                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
+                self.proj_in = linear_cls(in_channels, inner_dim)
             else:
-                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
             assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -197,9 +200,9 @@
         if self.is_input_continuous:
             # TODO: should use out_channels for continuous projections
             if use_linear_projection:
-                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
+                self.proj_out = linear_cls(inner_dim, in_channels)
             else:
-                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
         elif self.is_input_vectorized:
             self.norm_out = nn.LayerNorm(inner_dim)
             self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
@@ -292,13 +295,21 @@
             hidden_states = self.norm(hidden_states)
             if not self.use_linear_projection:
-                hidden_states = self.proj_in(hidden_states, scale=lora_scale)
+                hidden_states = (
+                    self.proj_in(hidden_states, scale=lora_scale)
+                    if not USE_PEFT_BACKEND
+                    else self.proj_in(hidden_states)
+                )
                 inner_dim = hidden_states.shape[1]
                 hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
             else:
                 inner_dim = hidden_states.shape[1]
                 hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = self.proj_in(hidden_states, scale=lora_scale)
+                hidden_states = (
+                    self.proj_in(hidden_states, scale=lora_scale)
+                    if not USE_PEFT_BACKEND
+                    else self.proj_in(hidden_states)
+                )
         elif self.is_input_vectorized:
             hidden_states = self.latent_image_embedding(hidden_states)
@@ -334,9 +345,17 @@
         if self.is_input_continuous:
             if not self.use_linear_projection:
                 hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-                hidden_states = self.proj_out(hidden_states, scale=lora_scale)
+                hidden_states = (
+                    self.proj_out(hidden_states, scale=lora_scale)
+                    if not USE_PEFT_BACKEND
+                    else self.proj_out(hidden_states)
+                )
             else:
-                hidden_states = self.proj_out(hidden_states, scale=lora_scale)
+                hidden_states = (
+                    self.proj_out(hidden_states, scale=lora_scale)
+                    if not USE_PEFT_BACKEND
+                    else self.proj_out(hidden_states)
+                )
                 hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
             output = hidden_states + residual
......
@@ -20,7 +20,7 @@ import torch.utils.checkpoint
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import BaseOutput, logging
+from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
 from .activations import get_activation
 from .attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS,
@@ -995,6 +995,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         # 3. down
         lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
         is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
@@ -1094,6 +1097,10 @@
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self)
+
         if not return_dict:
             return (sample,)
......
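
With the hooks above in place, the per-call LoRA strength travels from the pipeline to the PEFT layers: the UNet reads `cross_attention_kwargs["scale"]`, calls `scale_lora_layers(self, lora_scale)` on entry and `unscale_lora_layers(self)` before returning, so the scaling only applies for that forward pass. A hedged end-to-end sketch (checkpoint, LoRA path, and prompt are placeholders):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder LoRA checkpoint

# the same kwargs dict is forwarded to the UNet; with the PEFT backend the "scale"
# entry temporarily rescales every PEFT LoRA layer for this call only
image = pipe(
    "a pixel-art astronaut riding a horse",
    num_inference_steps=25,
    cross_attention_kwargs={"scale": 0.7},
).images[0]
image.save("out.png")
```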
@@ -25,7 +25,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -304,7 +311,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             self._lora_scale = lora_scale
             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -432,7 +439,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
@@ -668,6 +675,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
+        # to deal with lora scaling and other possible forward hooks
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -689,9 +697,8 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
         do_classifier_free_guidance = guidance_scale > 1.0
         # 3. Encode input prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        )
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
@@ -700,7 +707,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             negative_prompt,
             prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
+            lora_scale=lora_scale,
             clip_skip=clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
......
@@ -29,6 +29,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
     deprecate,
     logging,
     replace_example_docstring,
@@ -309,7 +310,7 @@ class AltDiffusionImg2ImgPipeline(
             self._lora_scale = lora_scale
             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -437,7 +438,7 @@ class AltDiffusionImg2ImgPipeline(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
......
@@ -27,7 +27,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
@@ -287,7 +294,7 @@ class StableDiffusionControlNetPipeline(
             self._lora_scale = lora_scale
             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -415,7 +422,7 @@ class StableDiffusionControlNetPipeline(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
......
@@ -27,6 +27,7 @@ from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    USE_PEFT_BACKEND,
     deprecate,
     logging,
     replace_example_docstring,
@@ -317,7 +318,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
             self._lora_scale = lora_scale
             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -445,7 +446,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
......
@@ -28,7 +28,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -438,7 +445,7 @@ class StableDiffusionControlNetInpaintPipeline(
             self._lora_scale = lora_scale
             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -566,7 +573,7 @@ class StableDiffusionControlNetInpaintPipeline(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
......
@@ -33,6 +33,7 @@ from ...models.attention_processor import (
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    USE_PEFT_BACKEND,
     is_invisible_watermark_available,
     logging,
     replace_example_docstring,
@@ -316,7 +317,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
             self._lora_scale = lora_scale
             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                 adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
             else:
@@ -458,7 +459,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
                 bs_embed * num_images_per_prompt, -1
             )
-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
             unscale_lora_layers(self.text_encoder_2)
......
@@ -35,7 +35,7 @@ from ...models.attention_processor import (
 )
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -285,7 +285,7 @@ class StableDiffusionXLControlNetPipeline(
             self._lora_scale = lora_scale
             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                 adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
             else:
@@ -427,7 +427,7 @@ class StableDiffusionXLControlNetPipeline(
                 bs_embed * num_images_per_prompt, -1
             )
-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
             unscale_lora_layers(self.text_encoder_2)
......
@@ -36,6 +36,7 @@ from ...models.attention_processor import (
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    USE_PEFT_BACKEND,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -328,7 +329,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
             self._lora_scale = lora_scale
             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                 adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
             else:
@@ -470,7 +471,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
                 bs_embed * num_images_per_prompt, -1
             )
-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
             unscale_lora_layers(self.text_encoder_2)
...
@@ -27,7 +27,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import DDIMScheduler
-from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import StableDiffusionPipelineOutput
@@ -308,7 +308,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
             self._lora_scale = lora_scale
             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -436,7 +436,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
...
@@ -25,7 +25,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import StableDiffusionPipelineOutput
@@ -297,7 +304,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
             self._lora_scale = lora_scale
             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -425,7 +432,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
@@ -658,6 +665,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
+        # to deal with lora scaling and other possible forward hooks
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -679,9 +687,8 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
         do_classifier_free_guidance = guidance_scale > 1.0
         # 3. Encode input prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        )
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
@@ -690,7 +697,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
+            lora_scale=lora_scale,
             clip_skip=clip_skip,
         )
         # For classifier free guidance, we need to do two forward passes.
...
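After this change, `__call__` simply reads the LoRA scale from `cross_attention_kwargs["scale"]` and forwards it to `encode_prompt` as `lora_scale`, so the user-facing API is unchanged. A minimal usage sketch, assuming a Stable Diffusion checkpoint plus some LoRA weights have been loaded (the model id, LoRA path, and prompt below are placeholders, not values from the diff):

import torch
from diffusers import StableDiffusionPipeline

# Placeholder checkpoint and LoRA file; any SD 1.x model with a LoRA would do.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora_weights.safetensors")

# The "scale" entry is picked up in __call__ as `lora_scale` and passed to
# encode_prompt, which scales (and later unscales) the text-encoder LoRA
# layers when the PEFT backend is available.
image = pipe(
    "a photo of an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.7},
).images[0]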
@@ -27,7 +27,14 @@ from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import Attention
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionPipelineOutput
@@ -332,7 +339,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
             self._lora_scale = lora_scale
             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -460,7 +467,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
...
@@ -28,7 +28,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -213,7 +213,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
             self._lora_scale = lora_scale
             # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                 adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             else:
                 scale_lora_layers(self.text_encoder, lora_scale)
@@ -341,7 +341,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
...