Unverified Commit 2bfa55f4 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`core` / `PEFT` / `LoRA`] Integrate PEFT into Unet (#5151)



* v1

* add tests and fix previous failing tests

* fix CI

* add tests + v1 `PeftLayerScaler`

* style

* add scale retrieving mechanism system

* fix CI

* up

* up

* simple approach --> not same results for some reason

* fix issues

* fix copies

* remove unneeded method

* active adapters!

* fix merge conflicts

* up

* up

* kohya - test-1

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* fix scale

* fix copies

* add comment

* multi adapters

* fix tests

* oops

* v1 faster loading - in progress

* Revert "v1 faster loading - in progress"

This reverts commit ac925f81321e95fc8168184c3346bf3d75404d5a.

* kohya same generation

* fix some slow tests

* peft integration features for unet lora

1. Support for Multiple ranks/alphas
2. Support for Multiple active adapters
3. Support for enabling/disabling LoRAs

* fix `get_peft_kwargs`

* Update loaders.py

* add some tests

* add unfuse tests

* fix tests

* up

* add set adapter from sourab and tests

* fix multi adapter tests

* style & quality

* style

* remove comment

* fix `adapter_name` issues

* fix unet adapter name for sdxl

* fix enabling/disabling adapters

* fix fuse / unfuse unet

* nit

* fix

* up

* fix cpu offloading

* fix another slow test

* fix another offload test

* add more tests

* all slow tests pass

* style

* fix alpha pattern for unet and text encoder

* Update src/diffusers/loaders.py
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/attention.py
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* up

* up

* clarify comment

* comments

* change comment order

* change comment order

* stylr & quality

* Update tests/lora/test_lora_layers_peft.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* fix bugs and add tests

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* refactor

* suggestion

* add break statemebt

* add compile tests

* move slow tests to peft tests as I modified them

* quality

* refactor a bit

* style

* change import

* style

* fix CI

* refactor slow tests one last time

* style

* oops

* oops

* oops

* final tweak tests

* Apply suggestions from code review
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* comments

* Apply suggestions from code review
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* remove comments

* more comments

* try

* revert

* add `safe_merge` tests

* add comment

* style, comments and run tests in fp16

* add warnings

* fix doc test

* replace with `adapter_weights`

* add `get_active_adapters()`

* expose `get_list_adapters` method

* better error message

* Apply suggestions from code review
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* style

* trigger slow lora tests

* fix tests

* maybe fix last test

* revert

* Update src/diffusers/loaders.py
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* move `MIN_PEFT_VERSION`

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* let's not use class variable

* fix few nits

* change a bit offloading logic

* check earlier

* rm unneeded block

* break long line

* return empty list

* change logic a bit and address comments

* add typehint

* remove parenthesis

* fix

* revert to fp16 in tests

* add to gpu

* revert to old test

* style

* Update src/diffusers/loaders.py
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* change indent

* Apply suggestions from code review

* Apply suggestions from code review

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
parent 9bc55e8b
......@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import re
from collections import defaultdict
......@@ -32,15 +31,16 @@ from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_i
from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
USE_PEFT_BACKEND,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
deprecate,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_omegaconf_available,
is_peft_available,
is_transformers_available,
logging,
recurse_remove_peft_layers,
......@@ -72,19 +72,6 @@ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
# available.
# For PEFT it is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1.
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse("0.5")
_required_transformers_version = version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse("4.33")
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."
......@@ -413,7 +400,7 @@ class UNet2DConditionLoadersMixin:
# fill attn processors
lora_layers_list = []
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
if is_lora:
......@@ -527,6 +514,10 @@ class UNet2DConditionLoadersMixin:
cross_attention_dim=cross_attention_dim,
)
attn_processors[key].load_state_dict(value_dict)
elif USE_PEFT_BACKEND:
# In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`
# on the Unet
pass
else:
raise ValueError(
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
......@@ -537,33 +528,36 @@ class UNet2DConditionLoadersMixin:
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
# For PEFT backend the Unet is already offloaded at this stage as it is handled inside `lora_lora_weights_into_unet`
if not USE_PEFT_BACKEND:
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# only custom diffusion needs to set attn processors
if is_custom_diffusion:
self.set_attn_processor(attn_processors)
# only custom diffusion needs to set attn processors
if is_custom_diffusion:
self.set_attn_processor(attn_processors)
# set lora layers
for target_module, lora_layer in lora_layers_list:
target_module.set_lora_layer(lora_layer)
# set lora layers
for target_module, lora_layer in lora_layers_list:
target_module.set_lora_layer(lora_layer)
self.to(dtype=self.dtype, device=self.device)
self.to(dtype=self.dtype, device=self.device)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
is_new_lora_format = all(
......@@ -686,15 +680,77 @@ class UNet2DConditionLoadersMixin:
self.apply(self._fuse_lora_apply)
def _fuse_lora_apply(self, module):
if hasattr(module, "_fuse_lora"):
module._fuse_lora(self.lora_scale, self._safe_fusing)
if not USE_PEFT_BACKEND:
if hasattr(module, "_fuse_lora"):
module._fuse_lora(self.lora_scale, self._safe_fusing)
else:
from peft.tuners.tuners_utils import BaseTunerLayer
if isinstance(module, BaseTunerLayer):
if self.lora_scale != 1.0:
module.scale_layer(self.lora_scale)
module.merge(safe_merge=self._safe_fusing)
def unfuse_lora(self):
self.apply(self._unfuse_lora_apply)
def _unfuse_lora_apply(self, module):
if hasattr(module, "_unfuse_lora"):
module._unfuse_lora()
if not USE_PEFT_BACKEND:
if hasattr(module, "_unfuse_lora"):
module._unfuse_lora()
else:
from peft.tuners.tuners_utils import BaseTunerLayer
if isinstance(module, BaseTunerLayer):
module.unmerge()
def set_adapters(
self,
adapter_names: Union[List[str], str],
weights: Optional[Union[List[float], float]] = None,
):
"""
Sets the adapter layers for the unet.
Args:
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
weights (`Union[List[float], float]`, *optional*):
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
adapters.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for `set_adapters()`.")
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
if weights is None:
weights = [1.0] * len(adapter_names)
elif isinstance(weights, float):
weights = [weights] * len(adapter_names)
if len(adapter_names) != len(weights):
raise ValueError(
f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
)
set_weights_and_activate_adapters(self, adapter_names, weights)
def disable_lora(self):
"""
Disables the active LoRA layers for the unet.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
set_adapter_layers(self, enabled=False)
def enable_lora(self):
"""
Enables the active LoRA layers for the unet.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
set_adapter_layers(self, enabled=True)
def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
......@@ -1113,7 +1169,6 @@ class LoraLoaderMixin:
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
num_fused_loras = 0
use_peft_backend = USE_PEFT_BACKEND
def load_lora_weights(
self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
......@@ -1155,6 +1210,7 @@ class LoraLoaderMixin:
network_alphas=network_alphas,
unet=self.unet,
low_cpu_mem_usage=low_cpu_mem_usage,
adapter_name=adapter_name,
_pipeline=self,
)
self.load_lora_into_text_encoder(
......@@ -1464,7 +1520,40 @@ class LoraLoaderMixin:
return new_state_dict
@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
@classmethod
def load_lora_into_unet(
cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
......@@ -1482,6 +1571,9 @@ class LoraLoaderMixin:
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
"""
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
......@@ -1508,6 +1600,56 @@ class LoraLoaderMixin:
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
logger.warn(warn_message)
if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
if adapter_name in getattr(unet, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
)
state_dict = convert_unet_state_dict_to_peft(state_dict)
if network_alphas is not None:
# The alphas state dict have the same structure as Unet, thus we convert it to peft format using
# `convert_unet_state_dict_to_peft` method.
network_alphas = convert_unet_state_dict_to_peft(network_alphas)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(unet)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
if incompatible_keys is not None:
# check only for unexpected keys
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
logger.warning(
f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
f" {unexpected_keys}. "
)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
)
......@@ -1570,7 +1712,7 @@ class LoraLoaderMixin:
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
if cls.use_peft_backend:
if USE_PEFT_BACKEND:
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
......@@ -1583,6 +1725,7 @@ class LoraLoaderMixin:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
else:
......@@ -1606,10 +1749,12 @@ class LoraLoaderMixin:
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
if cls.use_peft_backend:
if USE_PEFT_BACKEND:
from peft import LoraConfig
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict)
lora_config_kwargs = get_peft_kwargs(
rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
)
lora_config = LoraConfig(**lora_config_kwargs)
......@@ -1617,17 +1762,18 @@ class LoraLoaderMixin:
if adapter_name is None:
adapter_name = get_adapter_name(text_encoder)
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
# inject LoRA layers and load the state dict
# in transformers we automatically check whether the adapter name is already in use or not
text_encoder.load_adapter(
adapter_name=adapter_name,
adapter_state_dict=text_encoder_lora_state_dict,
peft_config=lora_config,
)
# scale LoRA layers with `lora_scale`
scale_lora_layers(text_encoder, weight=lora_scale)
is_model_cpu_offload = False
is_sequential_cpu_offload = False
else:
cls._modify_text_encoder(
text_encoder,
......@@ -1699,7 +1845,7 @@ class LoraLoaderMixin:
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
def _remove_text_encoder_monkey_patch(self):
if self.use_peft_backend:
if USE_PEFT_BACKEND:
remove_method = recurse_remove_peft_layers
else:
remove_method = self._remove_text_encoder_monkey_patch_classmethod
......@@ -1707,12 +1853,13 @@ class LoraLoaderMixin:
if hasattr(self, "text_encoder"):
remove_method(self.text_encoder)
if self.use_peft_backend:
# In case text encoder have no Lora attached
if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None:
del self.text_encoder.peft_config
self.text_encoder._hf_peft_config_loaded = None
if hasattr(self, "text_encoder_2"):
remove_method(self.text_encoder_2)
if self.use_peft_backend:
if USE_PEFT_BACKEND:
del self.text_encoder_2.peft_config
self.text_encoder_2._hf_peft_config_loaded = None
......@@ -2088,9 +2235,20 @@ class LoraLoaderMixin:
>>> ...
```
"""
for _, module in self.unet.named_modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
if not USE_PEFT_BACKEND:
if version.parse(__version__) > version.parse("0.23"):
logger.warn(
"You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
"you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
)
for _, module in self.unet.named_modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
else:
recurse_remove_peft_layers(self.unet)
if hasattr(self.unet, "peft_config"):
del self.unet.peft_config
# Safe to call the following regardless of LoRA.
self._remove_text_encoder_monkey_patch()
......@@ -2131,7 +2289,7 @@ class LoraLoaderMixin:
if fuse_unet:
self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)
if self.use_peft_backend:
if USE_PEFT_BACKEND:
from peft.tuners.tuners_utils import BaseTunerLayer
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
......@@ -2184,9 +2342,16 @@ class LoraLoaderMixin:
LoRA parameters then it won't have any effect.
"""
if unfuse_unet:
self.unet.unfuse_lora()
if not USE_PEFT_BACKEND:
self.unet.unfuse_lora()
else:
from peft.tuners.tuners_utils import BaseTunerLayer
if self.use_peft_backend:
for module in self.unet.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
if USE_PEFT_BACKEND:
from peft.tuners.tuners_utils import BaseTunerLayer
def unfuse_text_encoder_lora(text_encoder):
......@@ -2219,7 +2384,7 @@ class LoraLoaderMixin:
self.num_fused_loras -= 1
def set_adapter_for_text_encoder(
def set_adapters_for_text_encoder(
self,
adapter_names: Union[List[str], str],
text_encoder: Optional[PreTrainedModel] = None,
......@@ -2237,7 +2402,7 @@ class LoraLoaderMixin:
text_encoder_weights (`List[float]`, *optional*):
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
"""
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
def process_weights(adapter_names, weights):
......@@ -2270,7 +2435,7 @@ class LoraLoaderMixin:
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
`text_encoder` attribute.
"""
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
text_encoder = text_encoder or getattr(self, "text_encoder", None)
......@@ -2287,13 +2452,146 @@ class LoraLoaderMixin:
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
attribute.
"""
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
text_encoder = text_encoder or getattr(self, "text_encoder", None)
if text_encoder is None:
raise ValueError("Text Encoder not found.")
set_adapter_layers(self.text_encoder, enabled=True)
def set_adapters(
self,
adapter_names: Union[List[str], str],
adapter_weights: Optional[List[float]] = None,
):
# Handle the UNET
self.unet.set_adapters(adapter_names, adapter_weights)
# Handle the Text Encoder
if hasattr(self, "text_encoder"):
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
if hasattr(self, "text_encoder_2"):
self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
def disable_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
# Disable unet adapters
self.unet.disable_lora()
# Disable text encoder adapters
if hasattr(self, "text_encoder"):
self.disable_lora_for_text_encoder(self.text_encoder)
if hasattr(self, "text_encoder_2"):
self.disable_lora_for_text_encoder(self.text_encoder_2)
def enable_lora(self):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
# Enable unet adapters
self.unet.enable_lora()
# Enable text encoder adapters
if hasattr(self, "text_encoder"):
self.enable_lora_for_text_encoder(self.text_encoder)
if hasattr(self, "text_encoder_2"):
self.enable_lora_for_text_encoder(self.text_encoder_2)
def get_active_adapters(self) -> List[str]:
"""
Gets the list of the current active adapters.
Example:
```python
from diffusers import DiffusionPipeline
pipeline = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
).to("cuda")
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipeline.get_active_adapters()
```
"""
if not USE_PEFT_BACKEND:
raise ValueError(
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
)
from peft.tuners.tuners_utils import BaseTunerLayer
active_adapters = []
for module in self.unet.modules():
if isinstance(module, BaseTunerLayer):
active_adapters = module.active_adapters
break
return active_adapters
def get_list_adapters(self) -> Dict[str, List[str]]:
"""
Gets the current list of all available adapters in the pipeline.
"""
if not USE_PEFT_BACKEND:
raise ValueError(
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
)
set_adapters = {}
if hasattr(self, "text_encoder") and hasattr(self.text_encoder, "peft_config"):
set_adapters["text_encoder"] = list(self.text_encoder.peft_config.keys())
if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
set_adapters["unet"] = list(self.unet.peft_config.keys())
return set_adapters
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
"""
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
you want to load multiple adapters and free some GPU memory.
Args:
adapter_names (`List[str]`):
List of adapters to send device to.
device (`Union[torch.device, str, int]`):
Device to send the adapters to. Can be either a torch device, a str or an integer.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
from peft.tuners.tuners_utils import BaseTunerLayer
# Handle the UNET
for unet_module in self.unet.modules():
if isinstance(unet_module, BaseTunerLayer):
for adapter_name in adapter_names:
unet_module.lora_A[adapter_name].to(device)
unet_module.lora_B[adapter_name].to(device)
# Handle the text encoder
modules_to_process = []
if hasattr(self, "text_encoder"):
modules_to_process.append(self.text_encoder)
if hasattr(self, "text_encoder_2"):
modules_to_process.append(self.text_encoder_2)
for text_encoder in modules_to_process:
# loop over submodules
for text_encoder_module in text_encoder.modules():
if isinstance(text_encoder_module, BaseTunerLayer):
for adapter_name in adapter_names:
text_encoder_module.lora_A[adapter_name].to(device)
text_encoder_module.lora_B[adapter_name].to(device)
class FromSingleFileMixin:
"""
......@@ -2878,7 +3176,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
"""This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
# Overrride to properly handle the loading and unloading of the additional text encoder.
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
**kwargs,
):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
`self.text_encoder`.
......@@ -2896,6 +3199,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
adapter_name (`str`, *optional*):
Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
`default_{i}` where i is the total number of adapters being loaded.
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
......@@ -2913,7 +3219,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
self.load_lora_into_unet(
state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
......@@ -2922,6 +3230,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
)
......@@ -2933,6 +3242,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
)
......@@ -2999,16 +3309,17 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
)
def _remove_text_encoder_monkey_patch(self):
if self.use_peft_backend:
if USE_PEFT_BACKEND:
recurse_remove_peft_layers(self.text_encoder)
# TODO: @younesbelkada handle this in transformers side
del self.text_encoder.peft_config
self.text_encoder._hf_peft_config_loaded = None
if getattr(self.text_encoder, "peft_config", None) is not None:
del self.text_encoder.peft_config
self.text_encoder._hf_peft_config_loaded = None
recurse_remove_peft_layers(self.text_encoder_2)
del self.text_encoder_2.peft_config
self.text_encoder_2._hf_peft_config_loaded = None
if getattr(self.text_encoder_2, "peft_config", None) is not None:
del self.text_encoder_2.peft_config
self.text_encoder_2._hf_peft_config_loaded = None
else:
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
......@@ -17,6 +17,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import USE_PEFT_BACKEND
from ..utils.torch_utils import maybe_allow_in_graph
from .activations import get_activation
from .attention_processor import Attention
......@@ -300,6 +301,7 @@ class FeedForward(nn.Module):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
......@@ -316,14 +318,15 @@ class FeedForward(nn.Module):
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
self.net.append(linear_cls(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
for module in self.net:
if isinstance(module, (LoRACompatibleLinear, GEGLU)):
if isinstance(module, compatible_cls):
hidden_states = module(hidden_states, scale)
else:
hidden_states = module(hidden_states)
......@@ -368,7 +371,9 @@ class GEGLU(nn.Module):
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
self.proj = linear_cls(dim_in, dim_out * 2)
def gelu(self, gate: torch.Tensor) -> torch.Tensor:
if gate.device.type != "mps":
......@@ -377,7 +382,8 @@ class GEGLU(nn.Module):
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
def forward(self, hidden_states, scale: float = 1.0):
hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
args = () if USE_PEFT_BACKEND else (scale,)
hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
......
......@@ -18,7 +18,7 @@ import torch
import torch.nn.functional as F
from torch import nn
from ..utils import deprecate, logging
from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .lora import LoRACompatibleLinear, LoRALinearLayer
......@@ -137,22 +137,27 @@ class Attention(nn.Module):
f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
)
self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)
if USE_PEFT_BACKEND:
linear_cls = nn.Linear
else:
linear_cls = LoRACompatibleLinear
self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
if not self.only_cross_attention:
# only relevant for the `AddedKVProcessor` classes
self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
else:
self.to_k = None
self.to_v = None
if self.added_kv_proj_dim is not None:
self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
# set attention processor
......@@ -545,6 +550,8 @@ class AttnProcessor:
):
residual = hidden_states
args = () if USE_PEFT_BACKEND else (scale,)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
......@@ -562,15 +569,15 @@ class AttnProcessor:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, scale=scale)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
key = attn.to_k(encoder_hidden_states, *args)
value = attn.to_v(encoder_hidden_states, *args)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
......@@ -581,7 +588,7 @@ class AttnProcessor:
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states, scale=scale)
hidden_states = attn.to_out[0](hidden_states, *args)
# dropout
hidden_states = attn.to_out[1](hidden_states)
......@@ -1007,15 +1014,20 @@ class AttnProcessor2_0:
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states, scale=scale)
args = () if USE_PEFT_BACKEND else (scale,)
query = attn.to_q(hidden_states, *args)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states, scale=scale)
value = attn.to_v(encoder_hidden_states, scale=scale)
key = (
attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
)
value = (
attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
......@@ -1035,7 +1047,9 @@ class AttnProcessor2_0:
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states, scale=scale)
hidden_states = (
attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
)
# dropout
hidden_states = attn.to_out[1](hidden_states)
......
......@@ -18,6 +18,7 @@ import numpy as np
import torch
from torch import nn
from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .lora import LoRACompatibleLinear
......@@ -166,8 +167,9 @@ class TimestepEmbedding(nn.Module):
cond_proj_dim=None,
):
super().__init__()
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
self.linear_1 = linear_cls(in_channels, time_embed_dim)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
......@@ -180,7 +182,7 @@ class TimestepEmbedding(nn.Module):
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)
if post_act_fn is None:
self.post_act = None
......
......@@ -32,10 +32,12 @@ from ..utils import (
DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME,
HF_HUB_OFFLINE,
MIN_PEFT_VERSION,
SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME,
_add_variant,
_get_model_file,
check_peft_version,
deprecate,
is_accelerate_available,
is_torch_version,
......@@ -187,6 +189,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_supports_gradient_checkpointing = False
_keys_to_ignore_on_load_unexpected = None
_hf_peft_config_loaded = False
def __init__(self):
super().__init__()
......@@ -292,6 +295,153 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
"""
self.set_use_memory_efficient_attention_xformers(False)
def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
r"""
Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
to the adapter to follow the convention of the PEFT library.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
[documentation](https://huggingface.co/docs/peft).
Args:
adapter_config (`[~peft.PeftConfig]`):
The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
methods.
adapter_name (`str`, *optional*, defaults to `"default"`):
The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
from peft import PeftConfig, inject_adapter_in_model
if not self._hf_peft_config_loaded:
self._hf_peft_config_loaded = True
elif adapter_name in self.peft_config:
raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
if not isinstance(adapter_config, PeftConfig):
raise ValueError(
f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
)
# Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
# handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
adapter_config.base_model_name_or_path = None
inject_adapter_in_model(adapter_config, self, adapter_name)
self.set_adapter(adapter_name)
def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
"""
Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
Args:
adapter_name (Union[str, List[str]])):
The list of adapters to set or the adapter name in case of single adapter.
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
if isinstance(adapter_name, str):
adapter_name = [adapter_name]
missing = set(adapter_name) - set(self.peft_config)
if len(missing) > 0:
raise ValueError(
f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
f" current loaded adapters are: {list(self.peft_config.keys())}"
)
from peft.tuners.tuners_utils import BaseTunerLayer
_adapters_has_been_set = False
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "set_adapter"):
module.set_adapter(adapter_name)
# Previous versions of PEFT does not support multi-adapter inference
elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
raise ValueError(
"You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
" `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
)
else:
module.active_adapter = adapter_name
_adapters_has_been_set = True
if not _adapters_has_been_set:
raise ValueError(
"Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
)
def disable_adapters(self) -> None:
r"""
Disable all adapters attached to the model and fallback to inference with the base model only.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft.tuners.tuners_utils import BaseTunerLayer
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
else:
# support for older PEFT versions
module.disable_adapters = True
def enable_adapters(self) -> None:
"""
Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
list of adapters to enable.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft.tuners.tuners_utils import BaseTunerLayer
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=True)
else:
# support for older PEFT versions
module.disable_adapters = False
def active_adapters(self) -> List[str]:
"""
Gets the current list of active adapters of the model.
If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
official documentation: https://huggingface.co/docs/peft
"""
check_peft_version(min_version=MIN_PEFT_VERSION)
if not self._hf_peft_config_loaded:
raise ValueError("No adapter loaded. Please load an adapter first.")
from peft.tuners.tuners_utils import BaseTunerLayer
for _, module in self.named_modules():
if isinstance(module, BaseTunerLayer):
return module.active_adapter
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
......
......@@ -20,6 +20,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm
......@@ -149,12 +150,13 @@ class Upsample2D(nn.Module):
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
conv = None
if use_conv_transpose:
conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
conv = conv_cls(self.channels, self.out_channels, 3, padding=1)
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if name == "conv":
......@@ -193,12 +195,12 @@ class Upsample2D(nn.Module):
# TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
if self.use_conv:
if self.name == "conv":
if isinstance(self.conv, LoRACompatibleConv):
if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
if isinstance(self.Conv2d_0, LoRACompatibleConv):
if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
hidden_states = self.Conv2d_0(hidden_states, scale)
else:
hidden_states = self.Conv2d_0(hidden_states)
......@@ -237,9 +239,10 @@ class Downsample2D(nn.Module):
self.padding = padding
stride = 2
self.name = name
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if use_conv:
conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
assert self.channels == self.out_channels
conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
......@@ -255,13 +258,18 @@ class Downsample2D(nn.Module):
def forward(self, hidden_states, scale: float = 1.0):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
pad = (0, 1, 0, 1)
hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
assert hidden_states.shape[1] == self.channels
if isinstance(self.conv, LoRACompatibleConv):
hidden_states = self.conv(hidden_states, scale)
if not USE_PEFT_BACKEND:
if isinstance(self.conv, LoRACompatibleConv):
hidden_states = self.conv(hidden_states, scale)
else:
hidden_states = self.conv(hidden_states)
else:
hidden_states = self.conv(hidden_states)
......@@ -608,6 +616,9 @@ class ResnetBlock2D(nn.Module):
self.time_embedding_norm = time_embedding_norm
self.skip_time_act = skip_time_act
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
if groups_out is None:
groups_out = groups
......@@ -618,13 +629,13 @@ class ResnetBlock2D(nn.Module):
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
self.time_emb_proj = linear_cls(temb_channels, out_channels)
elif self.time_embedding_norm == "scale_shift":
self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
self.time_emb_proj = None
else:
......@@ -641,7 +652,7 @@ class ResnetBlock2D(nn.Module):
self.dropout = torch.nn.Dropout(dropout)
conv_2d_out_channels = conv_2d_out_channels or out_channels
self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
self.nonlinearity = get_activation(non_linearity)
......@@ -667,7 +678,7 @@ class ResnetBlock2D(nn.Module):
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = LoRACompatibleConv(
self.conv_shortcut = conv_cls(
in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
)
......@@ -708,12 +719,16 @@ class ResnetBlock2D(nn.Module):
else self.downsample(hidden_states)
)
hidden_states = self.conv1(hidden_states, scale)
hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)
if self.time_emb_proj is not None:
if not self.skip_time_act:
temb = self.nonlinearity(temb)
temb = self.time_emb_proj(temb, scale)[:, :, None, None]
temb = (
self.time_emb_proj(temb, scale)[:, :, None, None]
if not USE_PEFT_BACKEND
else self.time_emb_proj(temb)[:, :, None, None]
)
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
......@@ -730,10 +745,12 @@ class ResnetBlock2D(nn.Module):
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states, scale)
hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor, scale)
input_tensor = (
self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
......
......@@ -20,7 +20,7 @@ from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput, deprecate
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
from .lora import LoRACompatibleConv, LoRACompatibleLinear
......@@ -100,6 +100,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
......@@ -139,9 +142,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
self.proj_in = linear_cls(in_channels, inner_dim)
else:
self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
......@@ -197,9 +200,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
self.proj_out = linear_cls(inner_dim, in_channels)
else:
self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
......@@ -292,13 +295,21 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states, scale=lora_scale)
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states, scale=lora_scale)
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
......@@ -334,9 +345,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states, scale=lora_scale)
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
else:
hidden_states = self.proj_out(hidden_states, scale=lora_scale)
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
......
......@@ -20,7 +20,7 @@ import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from .activations import get_activation
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
......@@ -995,6 +995,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
# 3. down
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
......@@ -1094,6 +1097,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self)
if not return_dict:
return (sample,)
......
......@@ -25,7 +25,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
......@@ -304,7 +311,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -432,7 +439,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......@@ -668,6 +675,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# 1. Check inputs. Raise error if not correct
self.check_inputs(
......@@ -689,9 +697,8 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
......@@ -700,7 +707,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
lora_scale=lora_scale,
clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
......
......@@ -29,6 +29,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
......@@ -309,7 +310,7 @@ class AltDiffusionImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -437,7 +438,7 @@ class AltDiffusionImg2ImgPipeline(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -27,7 +27,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
......@@ -287,7 +294,7 @@ class StableDiffusionControlNetPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -415,7 +422,7 @@ class StableDiffusionControlNetPipeline(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -27,6 +27,7 @@ from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
......@@ -317,7 +318,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -445,7 +446,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -28,7 +28,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
......@@ -438,7 +445,7 @@ class StableDiffusionControlNetInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -566,7 +573,7 @@ class StableDiffusionControlNetInpaintPipeline(
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -33,6 +33,7 @@ from ...models.attention_processor import (
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
is_invisible_watermark_available,
logging,
replace_example_docstring,
......@@ -316,7 +317,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
......@@ -458,7 +459,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
bs_embed * num_images_per_prompt, -1
)
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
......
......@@ -35,7 +35,7 @@ from ...models.attention_processor import (
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
......@@ -285,7 +285,7 @@ class StableDiffusionXLControlNetPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
......@@ -427,7 +427,7 @@ class StableDiffusionXLControlNetPipeline(
bs_embed * num_images_per_prompt, -1
)
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
......
......@@ -36,6 +36,7 @@ from ...models.attention_processor import (
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
USE_PEFT_BACKEND,
logging,
replace_example_docstring,
scale_lora_layers,
......@@ -328,7 +329,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
......@@ -470,7 +471,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
bs_embed * num_images_per_prompt, -1
)
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
......
......@@ -27,7 +27,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMScheduler
from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import StableDiffusionPipelineOutput
......@@ -308,7 +308,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -436,7 +436,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -25,7 +25,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import StableDiffusionPipelineOutput
......@@ -297,7 +304,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -425,7 +432,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......@@ -658,6 +665,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# 1. Check inputs. Raise error if not correct
self.check_inputs(
......@@ -679,9 +687,8 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
......@@ -690,7 +697,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
lora_scale=lora_scale,
clip_skip=clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
......
......@@ -27,7 +27,14 @@ from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
......@@ -332,7 +339,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -460,7 +467,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -28,7 +28,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
......@@ -213,7 +213,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -341,7 +341,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment