Unverified commit 2bfa55f4, authored by Younes Belkada and committed by GitHub

[`core` / `PEFT` / `LoRA`] Integrate PEFT into Unet (#5151)



* v1

* add tests and fix previous failing tests

* fix CI

* add tests + v1 `PeftLayerScaler`

* style

* add scale retrieving mechanism

* fix CI

* up

* up

* simple approach --> not same results for some reason

* fix issues

* fix copies

* remove unneeded method

* active adapters!

* fix merge conflicts

* up

* up

* kohya - test-1

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix scale

* fix copies

* add comment

* multi adapters

* fix tests

* oops

* v1 faster loading - in progress

* Revert "v1 faster loading - in progress"

This reverts commit ac925f81321e95fc8168184c3346bf3d75404d5a.

* kohya same generation

* fix some slow tests

* peft integration features for unet lora (see the usage sketch after this list)

1. Support for Multiple ranks/alphas
2. Support for Multiple active adapters
3. Support for enabling/disabling LoRAs
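A minimal usage sketch of the pipeline-level API these commits introduce (the checkpoint and second LoRA repo IDs are illustrative placeholders; a PEFT/transformers install recent enough to enable `USE_PEFT_BACKEND` is assumed):

```python
import torch

from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Each call loads one adapter; ranks/alphas may differ between adapters.
pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
# Second LoRA repo below is a placeholder for any other SDXL LoRA.
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# Activate both adapters at once, each with its own weight.
pipe.set_adapters(["toy", "pixel"], adapter_weights=[1.0, 0.5])
print(pipe.get_active_adapters())  # e.g. ["toy", "pixel"]

# Temporarily run without any LoRA, then switch it back on.
pipe.disable_lora()
pipe.enable_lora()
```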

* fix `get_peft_kwargs`

* Update loaders.py

* add some tests

* add unfuse tests

* fix tests

* up

* add set adapter from sourab and tests

* fix multi adapter tests

* style & quality

* style

* remove comment

* fix `adapter_name` issues

* fix unet adapter name for sdxl

* fix enabling/disabling adapters

* fix fuse / unfuse unet

* nit

* fix

* up

* fix cpu offloading

* fix another slow test

* fix another offload test

* add more tests

* all slow tests pass

* style

* fix alpha pattern for unet and text encoder

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/attention.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* up

* up

* clarify comment

* comments

* change comment order

* change comment order

* style & quality

* Update tests/lora/test_lora_layers_peft.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix bugs and add tests

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* refactor

* suggestion

* add break statement

* add compile tests

* move slow tests to peft tests as I modified them

* quality

* refactor a bit

* style

* change import

* style

* fix CI

* refactor slow tests one last time

* style

* oops

* oops

* oops

* final tweak tests

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* comments

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* remove comments

* more comments

* try

* revert

* add `safe_merge` tests

* add comment

* style, comments and run tests in fp16

* add warnings

* fix doc test

* replace with `adapter_weights`

* add `get_active_adapters()`

* expose `get_list_adapters` method

* better error message

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* style

* trigger slow lora tests

* fix tests

* maybe fix last test

* revert

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* move `MIN_PEFT_VERSION`

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* let's not use class variable

* fix few nits

* change a bit offloading logic

* check earlier

* rm unneeded block

* break long line

* return empty list

* change logic a bit and address comments

* add typehint

* remove parenthesis

* fix

* revert to fp16 in tests

* add to gpu

* revert to old test

* style

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* change indent

* Apply suggestions from code review

* Apply suggestions from code review

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent 9bc55e8b
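For orientation before the diff, here is a hedged sketch of the lower-level `ModelMixin` adapter API that this commit adds to `modeling_utils.py` (the checkpoint, `LoraConfig` rank/alpha, and target modules are illustrative values, not taken from the PR):

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"  # illustrative checkpoint
)

# Hypothetical config; when loading an existing LoRA file, the real ranks/alphas
# are derived from its state dict instead.
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_k", "to_v", "to_out.0"])

unet.add_adapter(lora_config, adapter_name="my_adapter")  # injects LoRA layers and activates them
print(unet.active_adapters())                             # currently active adapter name(s)

unet.disable_adapters()         # run inference with the base weights only
unet.enable_adapters()          # re-enable the injected LoRA layers
unet.set_adapter("my_adapter")  # choose which loaded adapter(s) are active
```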
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib
 import os
 import re
 from collections import defaultdict
@@ -32,15 +31,16 @@ from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_i
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
+    USE_PEFT_BACKEND,
     _get_model_file,
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
+    convert_unet_state_dict_to_peft,
     deprecate,
     get_adapter_name,
     get_peft_kwargs,
     is_accelerate_available,
     is_omegaconf_available,
-    is_peft_available,
     is_transformers_available,
     logging,
     recurse_remove_peft_layers,
@@ -72,19 +72,6 @@ TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
 CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
 CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"

-# Below should be `True` if the current version of `peft` and `transformers` are compatible with
-# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
-# available.
-# For PEFT it is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1.
-_required_peft_version = is_peft_available() and version.parse(
-    version.parse(importlib.metadata.version("peft")).base_version
-) > version.parse("0.5")
-_required_transformers_version = version.parse(
-    version.parse(importlib.metadata.version("transformers")).base_version
-) > version.parse("4.33")
-USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
-
 LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."
@@ -413,7 +400,7 @@ class UNet2DConditionLoadersMixin:
         # fill attn processors
         lora_layers_list = []

-        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
+        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
         is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())

         if is_lora:
@@ -527,6 +514,10 @@
                     cross_attention_dim=cross_attention_dim,
                 )

                 attn_processors[key].load_state_dict(value_dict)
+        elif USE_PEFT_BACKEND:
+            # In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`
+            # on the Unet
+            pass
         else:
             raise ValueError(
                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
@@ -537,12 +528,15 @@
         # Now we remove any existing hooks to
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `lora_lora_weights_into_unet`
+        if not USE_PEFT_BACKEND:
             if _pipeline is not None:
                 for _, component in _pipeline.components.items():
-                    if isinstance(component, nn.Module):
-                        if hasattr(component, "_hf_hook"):
+                    if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                         is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                         is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)

                         logger.info(
                             "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                         )
@@ -686,15 +680,77 @@ class UNet2DConditionLoadersMixin:
         self.apply(self._fuse_lora_apply)

     def _fuse_lora_apply(self, module):
+        if not USE_PEFT_BACKEND:
             if hasattr(module, "_fuse_lora"):
                 module._fuse_lora(self.lora_scale, self._safe_fusing)
+        else:
+            from peft.tuners.tuners_utils import BaseTunerLayer
+
+            if isinstance(module, BaseTunerLayer):
+                if self.lora_scale != 1.0:
+                    module.scale_layer(self.lora_scale)
+                module.merge(safe_merge=self._safe_fusing)

     def unfuse_lora(self):
         self.apply(self._unfuse_lora_apply)

     def _unfuse_lora_apply(self, module):
+        if not USE_PEFT_BACKEND:
             if hasattr(module, "_unfuse_lora"):
                 module._unfuse_lora()
+        else:
+            from peft.tuners.tuners_utils import BaseTunerLayer
+
+            if isinstance(module, BaseTunerLayer):
+                module.unmerge()

+    def set_adapters(
+        self,
+        adapter_names: Union[List[str], str],
+        weights: Optional[Union[List[float], float]] = None,
+    ):
+        """
+        Sets the adapter layers for the unet.
+
+        Args:
+            adapter_names (`List[str]` or `str`):
+                The names of the adapters to use.
+            weights (`Union[List[float], float]`, *optional*):
+                The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
+                adapters.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `set_adapters()`.")
+
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+        if weights is None:
+            weights = [1.0] * len(adapter_names)
+        elif isinstance(weights, float):
+            weights = [weights] * len(adapter_names)
+
+        if len(adapter_names) != len(weights):
+            raise ValueError(
+                f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
+            )
+
+        set_weights_and_activate_adapters(self, adapter_names, weights)
+
+    def disable_lora(self):
+        """
+        Disables the active LoRA layers for the unet.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=False)
+
+    def enable_lora(self):
+        """
+        Enables the active LoRA layers for the unet.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+        set_adapter_layers(self, enabled=True)


 def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
@@ -1113,7 +1169,6 @@ class LoraLoaderMixin:
     text_encoder_name = TEXT_ENCODER_NAME
     unet_name = UNET_NAME
     num_fused_loras = 0
-    use_peft_backend = USE_PEFT_BACKEND

     def load_lora_weights(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], adapter_name=None, **kwargs
@@ -1155,6 +1210,7 @@ class LoraLoaderMixin:
             network_alphas=network_alphas,
             unet=self.unet,
             low_cpu_mem_usage=low_cpu_mem_usage,
+            adapter_name=adapter_name,
             _pipeline=self,
         )
         self.load_lora_into_text_encoder(
@@ -1464,7 +1520,40 @@ class LoraLoaderMixin:
         return new_state_dict

     @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
+    def _optionally_disable_offloading(cls, _pipeline):
+        """
+        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
+
+        Args:
+            _pipeline (`DiffusionPipeline`):
+                The pipeline to disable offloading for.
+
+        Returns:
+            tuple:
+                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+        """
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+
+        if _pipeline is not None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                    if not is_model_cpu_offload:
+                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                    if not is_sequential_cpu_offload:
+                        is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
+
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+        return (is_model_cpu_offload, is_sequential_cpu_offload)
+
+    @classmethod
+    def load_lora_into_unet(
+        cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, adapter_name=None, _pipeline=None
+    ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1482,6 +1571,9 @@
                 tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
                 Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                 argument to `True` will raise an error.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
         """
         low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
@@ -1508,6 +1600,56 @@
             warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
             logger.warn(warn_message)

+        if USE_PEFT_BACKEND and len(state_dict.keys()) > 0:
+            from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+            if adapter_name in getattr(unet, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
+                )
+
+            state_dict = convert_unet_state_dict_to_peft(state_dict)
+
+            if network_alphas is not None:
+                # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
+                # `convert_unet_state_dict_to_peft` method.
+                network_alphas = convert_unet_state_dict_to_peft(network_alphas)
+
+            rank = {}
+            for key, val in state_dict.items():
+                if "lora_B" in key:
+                    rank[key] = val.shape[1]
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(unet)
+
+            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+            # otherwise loading LoRA weights will lead to an error
+            is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
+            inject_adapter_in_model(lora_config, unet, adapter_name=adapter_name)
+            incompatible_keys = set_peft_model_state_dict(unet, state_dict, adapter_name)
+
+            if incompatible_keys is not None:
+                # check only for unexpected keys
+                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                if unexpected_keys:
+                    logger.warning(
+                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                        f" {unexpected_keys}. "
+                    )
+
+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
         unet.load_attn_procs(
             state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
         )
@@ -1570,7 +1712,7 @@ class LoraLoaderMixin:
             rank = {}
             text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)

-            if cls.use_peft_backend:
+            if USE_PEFT_BACKEND:
                 # convert state dict
                 text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
@@ -1583,6 +1725,7 @@ class LoraLoaderMixin:
                 for name, _ in text_encoder_mlp_modules(text_encoder):
                     rank_key_fc1 = f"{name}.fc1.lora_B.weight"
                     rank_key_fc2 = f"{name}.fc2.lora_B.weight"

                     rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
                     rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
             else:
@@ -1606,10 +1749,12 @@ class LoraLoaderMixin:
                     k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
                 }

-            if cls.use_peft_backend:
+            if USE_PEFT_BACKEND:
                 from peft import LoraConfig

-                lora_config_kwargs = get_peft_kwargs(rank, network_alphas, text_encoder_lora_state_dict)
+                lora_config_kwargs = get_peft_kwargs(
+                    rank, network_alphas, text_encoder_lora_state_dict, is_unet=False
+                )
                 lora_config = LoraConfig(**lora_config_kwargs)
@@ -1617,17 +1762,18 @@
                 if adapter_name is None:
                     adapter_name = get_adapter_name(text_encoder)

+                is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
+
                 # inject LoRA layers and load the state dict
+                # in transformers we automatically check whether the adapter name is already in use or not
                 text_encoder.load_adapter(
                     adapter_name=adapter_name,
                     adapter_state_dict=text_encoder_lora_state_dict,
                     peft_config=lora_config,
                 )

                 # scale LoRA layers with `lora_scale`
                 scale_lora_layers(text_encoder, weight=lora_scale)

-                is_model_cpu_offload = False
-                is_sequential_cpu_offload = False
             else:
                 cls._modify_text_encoder(
                     text_encoder,
@@ -1699,7 +1845,7 @@ class LoraLoaderMixin:
         return self._lora_scale if hasattr(self, "_lora_scale") else 1.0

     def _remove_text_encoder_monkey_patch(self):
-        if self.use_peft_backend:
+        if USE_PEFT_BACKEND:
             remove_method = recurse_remove_peft_layers
         else:
             remove_method = self._remove_text_encoder_monkey_patch_classmethod
@@ -1707,12 +1853,13 @@
         if hasattr(self, "text_encoder"):
             remove_method(self.text_encoder)
-            if self.use_peft_backend:
+            # In case text encoder have no Lora attached
+            if USE_PEFT_BACKEND and getattr(self.text_encoder, "peft_config", None) is not None:
                 del self.text_encoder.peft_config
                 self.text_encoder._hf_peft_config_loaded = None
         if hasattr(self, "text_encoder_2"):
             remove_method(self.text_encoder_2)
-            if self.use_peft_backend:
+            if USE_PEFT_BACKEND:
                 del self.text_encoder_2.peft_config
                 self.text_encoder_2._hf_peft_config_loaded = None
@@ -2088,9 +2235,20 @@ class LoraLoaderMixin:
         >>> ...
         ```
         """
+        if not USE_PEFT_BACKEND:
+            if version.parse(__version__) > version.parse("0.23"):
+                logger.warn(
+                    "You are using `unload_lora_weights` to disable and unload lora weights. If you want to iteratively enable and disable adapter weights,"
+                    "you can use `pipe.enable_lora()` or `pipe.disable_lora()`. After installing the latest version of PEFT."
+                )
+
             for _, module in self.unet.named_modules():
                 if hasattr(module, "set_lora_layer"):
                     module.set_lora_layer(None)
+        else:
+            recurse_remove_peft_layers(self.unet)
+            if hasattr(self.unet, "peft_config"):
+                del self.unet.peft_config

         # Safe to call the following regardless of LoRA.
         self._remove_text_encoder_monkey_patch()
@@ -2131,7 +2289,7 @@
         if fuse_unet:
             self.unet.fuse_lora(lora_scale, safe_fusing=safe_fusing)

-        if self.use_peft_backend:
+        if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer

             def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False):
@@ -2184,9 +2342,16 @@
             LoRA parameters then it won't have any effect.
         """
         if unfuse_unet:
+            if not USE_PEFT_BACKEND:
                 self.unet.unfuse_lora()
+            else:
+                from peft.tuners.tuners_utils import BaseTunerLayer
+
+                for module in self.unet.modules():
+                    if isinstance(module, BaseTunerLayer):
+                        module.unmerge()

-        if self.use_peft_backend:
+        if USE_PEFT_BACKEND:
             from peft.tuners.tuners_utils import BaseTunerLayer

             def unfuse_text_encoder_lora(text_encoder):
@@ -2219,7 +2384,7 @@
         self.num_fused_loras -= 1

-    def set_adapter_for_text_encoder(
+    def set_adapters_for_text_encoder(
         self,
         adapter_names: Union[List[str], str],
         text_encoder: Optional[PreTrainedModel] = None,
@@ -2237,7 +2402,7 @@
             text_encoder_weights (`List[float]`, *optional*):
                 The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
         """
-        if not self.use_peft_backend:
+        if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

         def process_weights(adapter_names, weights):
@@ -2270,7 +2435,7 @@
                 The text encoder module to disable the LoRA layers for. If `None`, it will try to get the
                 `text_encoder` attribute.
         """
-        if not self.use_peft_backend:
+        if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

         text_encoder = text_encoder or getattr(self, "text_encoder", None)
@@ -2287,13 +2452,146 @@
                 The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
                 attribute.
         """
-        if not self.use_peft_backend:
+        if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")

         text_encoder = text_encoder or getattr(self, "text_encoder", None)
         if text_encoder is None:
             raise ValueError("Text Encoder not found.")
         set_adapter_layers(self.text_encoder, enabled=True)
+    def set_adapters(
+        self,
+        adapter_names: Union[List[str], str],
+        adapter_weights: Optional[List[float]] = None,
+    ):
+        # Handle the UNET
+        self.unet.set_adapters(adapter_names, adapter_weights)
+
+        # Handle the Text Encoder
+        if hasattr(self, "text_encoder"):
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder, adapter_weights)
+        if hasattr(self, "text_encoder_2"):
+            self.set_adapters_for_text_encoder(adapter_names, self.text_encoder_2, adapter_weights)
+
+    def disable_lora(self):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        # Disable unet adapters
+        self.unet.disable_lora()
+
+        # Disable text encoder adapters
+        if hasattr(self, "text_encoder"):
+            self.disable_lora_for_text_encoder(self.text_encoder)
+        if hasattr(self, "text_encoder_2"):
+            self.disable_lora_for_text_encoder(self.text_encoder_2)
+
+    def enable_lora(self):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        # Enable unet adapters
+        self.unet.enable_lora()
+
+        # Enable text encoder adapters
+        if hasattr(self, "text_encoder"):
+            self.enable_lora_for_text_encoder(self.text_encoder)
+        if hasattr(self, "text_encoder_2"):
+            self.enable_lora_for_text_encoder(self.text_encoder_2)
+
+    def get_active_adapters(self) -> List[str]:
+        """
+        Gets the list of the current active adapters.
+
+        Example:
+
+        ```python
+        from diffusers import DiffusionPipeline
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0",
+        ).to("cuda")
+        pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
+        pipeline.get_active_adapters()
+        ```
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError(
+                "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
+            )
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        active_adapters = []
+
+        for module in self.unet.modules():
+            if isinstance(module, BaseTunerLayer):
+                active_adapters = module.active_adapters
+                break
+
+        return active_adapters
+
+    def get_list_adapters(self) -> Dict[str, List[str]]:
+        """
+        Gets the current list of all available adapters in the pipeline.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError(
+                "PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
+            )
+
+        set_adapters = {}
+
+        if hasattr(self, "text_encoder") and hasattr(self.text_encoder, "peft_config"):
+            set_adapters["text_encoder"] = list(self.text_encoder.peft_config.keys())
+
+        if hasattr(self, "text_encoder_2") and hasattr(self.text_encoder_2, "peft_config"):
+            set_adapters["text_encoder_2"] = list(self.text_encoder_2.peft_config.keys())
+
+        if hasattr(self, "unet") and hasattr(self.unet, "peft_config"):
+            set_adapters["unet"] = list(self.unet.peft_config.keys())
+
+        return set_adapters
+
+    def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
+        """
+        Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
+        you want to load multiple adapters and free some GPU memory.
+
+        Args:
+            adapter_names (`List[str]`):
+                List of adapters to send device to.
+            device (`Union[torch.device, str, int]`):
+                Device to send the adapters to. Can be either a torch device, a str or an integer.
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        # Handle the UNET
+        for unet_module in self.unet.modules():
+            if isinstance(unet_module, BaseTunerLayer):
+                for adapter_name in adapter_names:
+                    unet_module.lora_A[adapter_name].to(device)
+                    unet_module.lora_B[adapter_name].to(device)
+
+        # Handle the text encoder
+        modules_to_process = []
+        if hasattr(self, "text_encoder"):
+            modules_to_process.append(self.text_encoder)
+
+        if hasattr(self, "text_encoder_2"):
+            modules_to_process.append(self.text_encoder_2)
+
+        for text_encoder in modules_to_process:
+            # loop over submodules
+            for text_encoder_module in text_encoder.modules():
+                if isinstance(text_encoder_module, BaseTunerLayer):
+                    for adapter_name in adapter_names:
+                        text_encoder_module.lora_A[adapter_name].to(device)
+                        text_encoder_module.lora_B[adapter_name].to(device)


 class FromSingleFileMixin:
     """
@@ -2878,7 +3176,12 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
     """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""

     # Overrride to properly handle the loading and unloading of the additional text encoder.
-    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        **kwargs,
+    ):
         """
         Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
         `self.text_encoder`.
@@ -2896,6 +3199,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
             kwargs (`dict`, *optional*):
                 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
         """
@@ -2913,7 +3219,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
+        self.load_lora_into_unet(
+            state_dict, network_alphas=network_alphas, unet=self.unet, adapter_name=adapter_name, _pipeline=self
+        )
         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
         if len(text_encoder_state_dict) > 0:
             self.load_lora_into_text_encoder(
@@ -2922,6 +3230,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
                 text_encoder=self.text_encoder,
                 prefix="text_encoder",
                 lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
                 _pipeline=self,
             )
@@ -2933,6 +3242,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
                 text_encoder=self.text_encoder_2,
                 prefix="text_encoder_2",
                 lora_scale=self.lora_scale,
+                adapter_name=adapter_name,
                 _pipeline=self,
             )
@@ -2999,14 +3309,15 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         )

     def _remove_text_encoder_monkey_patch(self):
-        if self.use_peft_backend:
+        if USE_PEFT_BACKEND:
             recurse_remove_peft_layers(self.text_encoder)
             # TODO: @younesbelkada handle this in transformers side
+            if getattr(self.text_encoder, "peft_config", None) is not None:
                 del self.text_encoder.peft_config
                 self.text_encoder._hf_peft_config_loaded = None

             recurse_remove_peft_layers(self.text_encoder_2)
+            if getattr(self.text_encoder_2, "peft_config", None) is not None:
                 del self.text_encoder_2.peft_config
                 self.text_encoder_2._hf_peft_config_loaded = None
         else:
...
@@ -17,6 +17,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

+from ..utils import USE_PEFT_BACKEND
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import get_activation
 from .attention_processor import Attention
@@ -300,6 +301,7 @@ class FeedForward(nn.Module):
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
+        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim)
@@ -316,14 +318,15 @@
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
+        self.net.append(linear_cls(inner_dim, dim_out))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))

     def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
         for module in self.net:
-            if isinstance(module, (LoRACompatibleLinear, GEGLU)):
+            if isinstance(module, compatible_cls):
                 hidden_states = module(hidden_states, scale)
             else:
                 hidden_states = module(hidden_states)
@@ -368,7 +371,9 @@ class GEGLU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
-        self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
+        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+        self.proj = linear_cls(dim_in, dim_out * 2)

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":
@@ -377,7 +382,8 @@
         return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

     def forward(self, hidden_states, scale: float = 1.0):
-        hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
+        args = () if USE_PEFT_BACKEND else (scale,)
+        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
         return hidden_states * self.gelu(gate)
...
@@ -18,7 +18,7 @@ import torch
 import torch.nn.functional as F
 from torch import nn

-from ..utils import deprecate, logging
+from ..utils import USE_PEFT_BACKEND, deprecate, logging
 from ..utils.import_utils import is_xformers_available
 from ..utils.torch_utils import maybe_allow_in_graph
 from .lora import LoRACompatibleLinear, LoRALinearLayer
@@ -137,22 +137,27 @@ class Attention(nn.Module):
                 f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
             )

-        self.to_q = LoRACompatibleLinear(query_dim, self.inner_dim, bias=bias)
+        if USE_PEFT_BACKEND:
+            linear_cls = nn.Linear
+        else:
+            linear_cls = LoRACompatibleLinear
+
+        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)

         if not self.only_cross_attention:
             # only relevant for the `AddedKVProcessor` classes
-            self.to_k = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
-            self.to_v = LoRACompatibleLinear(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
+            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
         else:
             self.to_k = None
             self.to_v = None

         if self.added_kv_proj_dim is not None:
-            self.add_k_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
-            self.add_v_proj = LoRACompatibleLinear(added_kv_proj_dim, self.inner_dim)
+            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
+            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)

         self.to_out = nn.ModuleList([])
-        self.to_out.append(LoRACompatibleLinear(self.inner_dim, query_dim, bias=out_bias))
+        self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))

         # set attention processor
@@ -545,6 +550,8 @@ class AttnProcessor:
     ):
         residual = hidden_states

+        args = () if USE_PEFT_BACKEND else (scale,)
+
         if attn.spatial_norm is not None:
             hidden_states = attn.spatial_norm(hidden_states, temb)
@@ -562,15 +569,15 @@
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, scale=scale)
+        query = attn.to_q(hidden_states, *args)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states, scale=scale)
-        value = attn.to_v(encoder_hidden_states, scale=scale)
+        key = attn.to_k(encoder_hidden_states, *args)
+        value = attn.to_v(encoder_hidden_states, *args)

         query = attn.head_to_batch_dim(query)
         key = attn.head_to_batch_dim(key)
@@ -581,7 +588,7 @@
         hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = attn.to_out[0](hidden_states, *args)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
@@ -1007,15 +1014,20 @@ class AttnProcessor2_0:
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

-        query = attn.to_q(hidden_states, scale=scale)
+        args = () if USE_PEFT_BACKEND else (scale,)
+        query = attn.to_q(hidden_states, *args)

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        key = attn.to_k(encoder_hidden_states, scale=scale)
-        value = attn.to_v(encoder_hidden_states, scale=scale)
+        key = (
+            attn.to_k(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_k(encoder_hidden_states)
+        )
+        value = (
+            attn.to_v(encoder_hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_v(encoder_hidden_states)
+        )

         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -1035,7 +1047,9 @@
         hidden_states = hidden_states.to(query.dtype)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states, scale=scale)
+        hidden_states = (
+            attn.to_out[0](hidden_states, scale=scale) if not USE_PEFT_BACKEND else attn.to_out[0](hidden_states)
+        )
         # dropout
         hidden_states = attn.to_out[1](hidden_states)
...
@@ -18,6 +18,7 @@ import numpy as np
 import torch
 from torch import nn

+from ..utils import USE_PEFT_BACKEND
 from .activations import get_activation
 from .lora import LoRACompatibleLinear
@@ -166,8 +167,9 @@ class TimestepEmbedding(nn.Module):
         cond_proj_dim=None,
     ):
         super().__init__()
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear

-        self.linear_1 = LoRACompatibleLinear(in_channels, time_embed_dim)
+        self.linear_1 = linear_cls(in_channels, time_embed_dim)

         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -180,7 +182,7 @@
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = LoRACompatibleLinear(time_embed_dim, time_embed_dim_out)
+        self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out)

         if post_act_fn is None:
             self.post_act = None
...
@@ -32,10 +32,12 @@ from ..utils import (
     DIFFUSERS_CACHE,
     FLAX_WEIGHTS_NAME,
     HF_HUB_OFFLINE,
+    MIN_PEFT_VERSION,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     _add_variant,
     _get_model_file,
+    check_peft_version,
     deprecate,
     is_accelerate_available,
     is_torch_version,
@@ -187,6 +189,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
+    _hf_peft_config_loaded = False

     def __init__(self):
         super().__init__()
@@ -292,6 +295,153 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         self.set_use_memory_efficient_attention_xformers(False)
+    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
+        r"""
+        Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned
+        to the adapter to follow the convention of the PEFT library.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
+        [documentation](https://huggingface.co/docs/peft).
+
+        Args:
+            adapter_config (`[~peft.PeftConfig]`):
+                The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
+                methods.
+            adapter_name (`str`, *optional*, defaults to `"default"`):
+                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        from peft import PeftConfig, inject_adapter_in_model
+
+        if not self._hf_peft_config_loaded:
+            self._hf_peft_config_loaded = True
+        elif adapter_name in self.peft_config:
+            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
+
+        if not isinstance(adapter_config, PeftConfig):
+            raise ValueError(
+                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
+            )
+
+        # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
+        # handled by the `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
+        adapter_config.base_model_name_or_path = None
+        inject_adapter_in_model(adapter_config, self, adapter_name)
+        self.set_adapter(adapter_name)
+
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+
+        Args:
+            adapter_name (Union[str, List[str]])):
+                The list of adapters to set or the adapter name in case of single adapter.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        if isinstance(adapter_name, str):
+            adapter_name = [adapter_name]
+
+        missing = set(adapter_name) - set(self.peft_config)
+        if len(missing) > 0:
+            raise ValueError(
+                f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
+                f" current loaded adapters are: {list(self.peft_config.keys())}"
+            )
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        _adapters_has_been_set = False
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "set_adapter"):
+                    module.set_adapter(adapter_name)
+                # Previous versions of PEFT does not support multi-adapter inference
+                elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
+                    raise ValueError(
+                        "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
+                        " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
+                    )
+                else:
+                    module.active_adapter = adapter_name
+                _adapters_has_been_set = True
+
+        if not _adapters_has_been_set:
+            raise ValueError(
+                "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters."
+            )
+
+    def disable_adapters(self) -> None:
+        r"""
+        Disable all adapters attached to the model and fallback to inference with the base model only.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=False)
+                else:
+                    # support for older PEFT versions
+                    module.disable_adapters = True
+
+    def enable_adapters(self) -> None:
+        """
+        Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
+        list of adapters to enable.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=True)
+                else:
+                    # support for older PEFT versions
+                    module.disable_adapters = False
+
+    def active_adapters(self) -> List[str]:
+        """
+        Gets the current list of active adapters of the model.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                return module.active_adapter
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
...
@@ -20,6 +20,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

+from ..utils import USE_PEFT_BACKEND
from .activations import get_activation
from .attention import AdaGroupNorm
from .attention_processor import SpatialNorm
@@ -149,12 +150,13 @@ class Upsample2D(nn.Module):
        self.use_conv = use_conv
        self.use_conv_transpose = use_conv_transpose
        self.name = name
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        conv = None
        if use_conv_transpose:
            conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
        elif use_conv:
-            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, padding=1)
+            conv = conv_cls(self.channels, self.out_channels, 3, padding=1)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if name == "conv":
@@ -193,12 +195,12 @@ class Upsample2D(nn.Module):
        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        if self.use_conv:
            if self.name == "conv":
-                if isinstance(self.conv, LoRACompatibleConv):
+                if isinstance(self.conv, LoRACompatibleConv) and not USE_PEFT_BACKEND:
                    hidden_states = self.conv(hidden_states, scale)
                else:
                    hidden_states = self.conv(hidden_states)
            else:
-                if isinstance(self.Conv2d_0, LoRACompatibleConv):
+                if isinstance(self.Conv2d_0, LoRACompatibleConv) and not USE_PEFT_BACKEND:
                    hidden_states = self.Conv2d_0(hidden_states, scale)
                else:
                    hidden_states = self.Conv2d_0(hidden_states)
@@ -237,9 +239,10 @@ class Downsample2D(nn.Module):
        self.padding = padding
        stride = 2
        self.name = name
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv

        if use_conv:
-            conv = LoRACompatibleConv(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+            conv = conv_cls(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
@@ -255,15 +258,20 @@
    def forward(self, hidden_states, scale: float = 1.0):
        assert hidden_states.shape[1] == self.channels

        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels

-        if isinstance(self.conv, LoRACompatibleConv):
-            hidden_states = self.conv(hidden_states, scale)
-        else:
-            hidden_states = self.conv(hidden_states)
+        if not USE_PEFT_BACKEND:
+            if isinstance(self.conv, LoRACompatibleConv):
+                hidden_states = self.conv(hidden_states, scale)
+            else:
+                hidden_states = self.conv(hidden_states)
+        else:
+            hidden_states = self.conv(hidden_states)

        return hidden_states
@@ -608,6 +616,9 @@ class ResnetBlock2D(nn.Module):
        self.time_embedding_norm = time_embedding_norm
        self.skip_time_act = skip_time_act

+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+
        if groups_out is None:
            groups_out = groups
@@ -618,13 +629,13 @@ class ResnetBlock2D(nn.Module):
        else:
            self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)

-        self.conv1 = LoRACompatibleConv(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        if temb_channels is not None:
            if self.time_embedding_norm == "default":
-                self.time_emb_proj = LoRACompatibleLinear(temb_channels, out_channels)
+                self.time_emb_proj = linear_cls(temb_channels, out_channels)
            elif self.time_embedding_norm == "scale_shift":
-                self.time_emb_proj = LoRACompatibleLinear(temb_channels, 2 * out_channels)
+                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
            elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
                self.time_emb_proj = None
            else:
@@ -641,7 +652,7 @@ class ResnetBlock2D(nn.Module):
        self.dropout = torch.nn.Dropout(dropout)
        conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = LoRACompatibleConv(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)

        self.nonlinearity = get_activation(non_linearity)
@@ -667,7 +678,7 @@ class ResnetBlock2D(nn.Module):
        self.conv_shortcut = None
        if self.use_in_shortcut:
-            self.conv_shortcut = LoRACompatibleConv(
+            self.conv_shortcut = conv_cls(
                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
            )
@@ -708,12 +719,16 @@ class ResnetBlock2D(nn.Module):
                else self.downsample(hidden_states)
            )

-        hidden_states = self.conv1(hidden_states, scale)
+        hidden_states = self.conv1(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
-            temb = self.time_emb_proj(temb, scale)[:, :, None, None]
+            temb = (
+                self.time_emb_proj(temb, scale)[:, :, None, None]
+                if not USE_PEFT_BACKEND
+                else self.time_emb_proj(temb)[:, :, None, None]
+            )

        if temb is not None and self.time_embedding_norm == "default":
            hidden_states = hidden_states + temb
@@ -730,10 +745,12 @@ class ResnetBlock2D(nn.Module):
        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.conv2(hidden_states, scale)
+        hidden_states = self.conv2(hidden_states, scale) if not USE_PEFT_BACKEND else self.conv2(hidden_states)

        if self.conv_shortcut is not None:
-            input_tensor = self.conv_shortcut(input_tensor, scale)
+            input_tensor = (
+                self.conv_shortcut(input_tensor, scale) if not USE_PEFT_BACKEND else self.conv_shortcut(input_tensor)
+            )

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
...
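The pattern above repeats in every block: when `USE_PEFT_BACKEND` is true, plain `nn.Conv2d`/`nn.Linear` layers are created and PEFT injects the LoRA weights into them; otherwise the LoRA-aware `LoRACompatibleConv`/`LoRACompatibleLinear` classes are kept. A standalone sketch of the idea, with simplified stand-ins for the diffusers classes:

import torch.nn as nn

USE_PEFT_BACKEND = True  # stand-in for the flag exported by diffusers.utils

class LoRACompatibleConv(nn.Conv2d):
    # simplified stand-in: the real class also applies its own LoRA weights
    def forward(self, x, scale: float = 1.0):
        return super().forward(x)

# pick the layer class once, then construct layers as usual
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
conv = conv_cls(4, 8, kernel_size=3, padding=1)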
@@ -20,7 +20,7 @@ from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
-from ..utils import BaseOutput, deprecate
+from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
from .lora import LoRACompatibleConv, LoRACompatibleLinear
@@ -100,6 +100,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
        self.attention_head_dim = attention_head_dim
        inner_dim = num_attention_heads * attention_head_dim

+        conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
+        linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
+
        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
        # Define whether input is continuous or discrete depending on configuration
        self.is_input_continuous = (in_channels is not None) and (patch_size is None)
@@ -139,9 +142,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
            self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
            if use_linear_projection:
-                self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
+                self.proj_in = linear_cls(in_channels, inner_dim)
            else:
-                self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+                self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
            assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
@@ -197,9 +200,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
        if self.is_input_continuous:
            # TODO: should use out_channels for continuous projections
            if use_linear_projection:
-                self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
+                self.proj_out = linear_cls(inner_dim, in_channels)
            else:
-                self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+                self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        elif self.is_input_vectorized:
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
@@ -292,13 +295,21 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
            hidden_states = self.norm(hidden_states)
            if not self.use_linear_projection:
-                hidden_states = self.proj_in(hidden_states, scale=lora_scale)
+                hidden_states = (
+                    self.proj_in(hidden_states, scale=lora_scale)
+                    if not USE_PEFT_BACKEND
+                    else self.proj_in(hidden_states)
+                )
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
            else:
                inner_dim = hidden_states.shape[1]
                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
-                hidden_states = self.proj_in(hidden_states, scale=lora_scale)
+                hidden_states = (
+                    self.proj_in(hidden_states, scale=lora_scale)
+                    if not USE_PEFT_BACKEND
+                    else self.proj_in(hidden_states)
+                )
        elif self.is_input_vectorized:
            hidden_states = self.latent_image_embedding(hidden_states)
@@ -334,9 +345,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
        if self.is_input_continuous:
            if not self.use_linear_projection:
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
-                hidden_states = self.proj_out(hidden_states, scale=lora_scale)
+                hidden_states = (
+                    self.proj_out(hidden_states, scale=lora_scale)
+                    if not USE_PEFT_BACKEND
+                    else self.proj_out(hidden_states)
+                )
            else:
-                hidden_states = self.proj_out(hidden_states, scale=lora_scale)
+                hidden_states = (
+                    self.proj_out(hidden_states, scale=lora_scale)
+                    if not USE_PEFT_BACKEND
+                    else self.proj_out(hidden_states)
+                )
                hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()

            output = hidden_states + residual
...
@@ -20,7 +20,7 @@ import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
-from ..utils import BaseOutput, logging
+from ..utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from .activations import get_activation
from .attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,
@@ -995,6 +995,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
@@ -1094,6 +1097,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self)
+
        if not return_dict:
            return (sample,)
...
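With the PEFT backend the per-call LoRA scale is no longer threaded through every `forward(..., scale=...)`; it travels in `cross_attention_kwargs` and is applied once around the UNet forward via `scale_lora_layers`/`unscale_lora_layers`, as in the hunks above. A hedged sketch of how a caller passes that scale (the `unet`, `sample`, `timestep`, and `encoder_hidden_states` objects are assumed to already exist):

out = unet(
    sample,
    timestep,
    encoder_hidden_states=encoder_hidden_states,
    cross_attention_kwargs={"scale": 0.7},  # read as `lora_scale` inside forward
).sample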
@@ -25,7 +25,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -304,7 +311,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)
@@ -432,7 +439,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
@@ -668,6 +675,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
+        # to deal with lora scaling and other possible forward hooks

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
@@ -689,9 +697,8 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        )
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
@@ -700,7 +707,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
+            lora_scale=lora_scale,
            clip_skip=clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
...
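Every pipeline now switches on `USE_PEFT_BACKEND` the same way when encoding prompts: the legacy path rescales the `LoRACompatible*` layers via `adjust_lora_scale_text_encoder`, while the PEFT path scales the injected LoRA layers and scales them back afterwards. A rough sketch of that pattern in isolation, assuming `text_encoder` already holds PEFT LoRA layers and `input_ids` is a tokenized prompt:

from diffusers.utils import scale_lora_layers, unscale_lora_layers

lora_scale = 0.8
scale_lora_layers(text_encoder, lora_scale)  # weight the LoRA layers for this call
prompt_embeds = text_encoder(input_ids)[0]
unscale_lora_layers(text_encoder)            # scale back so later calls are unaffected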
@@ -29,6 +29,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
    PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
@@ -309,7 +310,7 @@ class AltDiffusionImg2ImgPipeline(
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)
@@ -437,7 +438,7 @@ class AltDiffusionImg2ImgPipeline(
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
...
@@ -27,7 +27,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
@@ -287,7 +294,7 @@ class StableDiffusionControlNetPipeline(
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)
@@ -415,7 +422,7 @@ class StableDiffusionControlNetPipeline(
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
...
@@ -27,6 +27,7 @@ from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
+    USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
@@ -317,7 +318,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)
@@ -445,7 +446,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
...
@@ -28,7 +28,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion import StableDiffusionPipelineOutput
@@ -438,7 +445,7 @@ class StableDiffusionControlNetInpaintPipeline(
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)
@@ -566,7 +573,7 @@ class StableDiffusionControlNetInpaintPipeline(
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
...
@@ -33,6 +33,7 @@ from ...models.attention_processor import (
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
+    USE_PEFT_BACKEND,
    is_invisible_watermark_available,
    logging,
    replace_example_docstring,
@@ -316,7 +317,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
@@ -458,7 +459,7 @@ class StableDiffusionXLControlNetInpaintPipeline(
                bs_embed * num_images_per_prompt, -1
            )

-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
            unscale_lora_layers(self.text_encoder_2)
...
@@ -35,7 +35,7 @@ from ...models.attention_processor import (
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import is_compiled_module, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -285,7 +285,7 @@ class StableDiffusionXLControlNetPipeline(
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
@@ -427,7 +427,7 @@ class StableDiffusionXLControlNetPipeline(
                bs_embed * num_images_per_prompt, -1
            )

-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
            unscale_lora_layers(self.text_encoder_2)
...
@@ -36,6 +36,7 @@ from ...models.attention_processor import (
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
+    USE_PEFT_BACKEND,
    logging,
    replace_example_docstring,
    scale_lora_layers,
@@ -328,7 +329,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
            else:
@@ -470,7 +471,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
                bs_embed * num_images_per_prompt, -1
            )

-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
            unscale_lora_layers(self.text_encoder_2)
...
@@ -27,7 +27,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import DDIMScheduler
-from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import StableDiffusionPipelineOutput
@@ -308,7 +308,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)
@@ -436,7 +436,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
...
@@ -25,7 +25,14 @@ from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoa
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import StableDiffusionPipelineOutput
@@ -297,7 +304,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)
@@ -425,7 +432,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
@@ -658,6 +665,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
+        # to deal with lora scaling and other possible forward hooks

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
@@ -679,9 +687,8 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
-        text_encoder_lora_scale = (
-            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-        )
+        lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            device,
@@ -690,7 +697,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
-            lora_scale=text_encoder_lora_scale,
+            lora_scale=lora_scale,
            clip_skip=clip_skip,
        )
        # For classifier free guidance, we need to do two forward passes.
...
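From the user side, these changes mean a single `scale` entry in `cross_attention_kwargs` now drives the LoRA strength for both the text encoder and the UNet. A hedged end-to-end sketch (model id and LoRA path are placeholders):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder path

image = pipe(
    "a photo of an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.5},  # forwarded as `lora_scale`
).images[0]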
@@ -27,7 +27,14 @@ from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.attention_processor import Attention
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
@@ -332,7 +339,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)
@@ -460,7 +467,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
...
@@ -28,7 +28,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import PIL_INTERPOLATION, USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -213,7 +213,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
-            if not self.use_peft_backend:
+            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
            else:
                scale_lora_layers(self.text_encoder, lora_scale)
@@ -341,7 +341,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
            unscale_lora_layers(self.text_encoder)
...