"vscode:/vscode.git/clone" did not exist on "b9b7cfc602d68e71b4e4039d15dddfe578df9db2"
Unverified commit 493f9529, authored by Younes Belkada, committed by GitHub

[`PEFT` / `LoRA`] PEFT integration - text encoder (#5058)



* more fixes

* up

* up

* style

* add in setup

* oops

* more changes

* v1 rzfactor CI

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* few todos

* protect torch import

* style

* fix fuse text encoder

* Update src/diffusers/loaders.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* replace with `recurse_replace_peft_layers`

* keep old modules for BC

* adjustments on `adjust_lora_scale_text_encoder`

* nit

* move tests

* add conversion utils

* remove unneeded methods

* use class method instead

* oops

* use `base_version`

* fix examples

* fix CI

* fix weird error with python 3.8

* fix

* better fix

* style

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* add comment

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* conv2d support for recurse remove

* added docstrings

* more docstring

* add deprecate

* revert

* try to fix merge conflicts

* v1 tests

* add new decorator

* add saving utilities test

* adapt tests a bit

* add save / from_pretrained tests

* add saving tests

* add scale tests

* fix deps tests

* fix lora CI

* fix tests

* add comment

* fix

* style

* add slow tests

* slow tests pass

* style

* Update src/diffusers/utils/import_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* circumvents pattern finding issue

* left a todo

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* update hub path

* add lora workflow

* fix

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
parent b32555a2
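
A minimal sketch of the user-facing flow this change targets, assuming compatible `peft` (> 0.5) and `transformers` (> 4.33) releases are installed; the model and LoRA paths below are placeholders, not part of the change:

import torch
from diffusers import StableDiffusionPipeline

# Placeholder ids: any SD checkpoint plus a LoRA file that also targets the
# text encoder exercises the new PEFT-backed code path.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_lora_weights("path/to/lora")

# Optionally fold the LoRA weights into the base weights before inference.
pipe.fuse_lora()
image = pipe("a photo of a corgi").images[0]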
name: Fast tests for PRs - PEFT backend

on:
  pull_request:
    branches:
      - main

concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: true

env:
  DIFFUSERS_IS_CI: yes
  OMP_NUM_THREADS: 4
  MKL_NUM_THREADS: 4
  PYTEST_TIMEOUT: 60

jobs:
  run_fast_tests:
    strategy:
      fail-fast: false
      matrix:
        config:
          - name: LoRA
            framework: lora
            runner: docker-cpu
            image: diffusers/diffusers-pytorch-cpu
            report: torch_cpu_lora

    name: ${{ matrix.config.name }}

    runs-on: ${{ matrix.config.runner }}

    container:
      image: ${{ matrix.config.image }}
      options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/

    defaults:
      run:
        shell: bash

    steps:
    - name: Checkout diffusers
      uses: actions/checkout@v3
      with:
        fetch-depth: 2

    - name: Install dependencies
      run: |
        apt-get update && apt-get install libsndfile1-dev libgl1 -y
        python -m pip install -e .[quality,test]
        python -m pip install git+https://github.com/huggingface/accelerate.git
        python -m pip install -U git+https://github.com/huggingface/transformers.git
        python -m pip install -U git+https://github.com/huggingface/peft.git

    - name: Environment
      run: |
        python utils/print_env.py

    - name: Run fast PyTorch LoRA CPU tests with PEFT backend
      if: ${{ matrix.config.framework == 'lora' }}
      run: |
        python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
          -s -v \
          --make-reports=tests_${{ matrix.config.report }} \
          tests/lora/test_lora_layers_peft.py
\ No newline at end of file
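
The job installs `accelerate`, `transformers`, and `peft` from their main branches, so the fast tests always run against the version checks that gate the new backend. The final pytest invocation can be reused as-is to run the new `tests/lora/test_lora_layers_peft.py` suite locally after a `pip install -e .[quality,test]` development setup.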
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import os
 import re
 from collections import defaultdict
@@ -23,6 +24,7 @@ import requests
 import safetensors
 import torch
 from huggingface_hub import hf_hub_download, model_info
+from packaging import version
 from torch import nn
 
 from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
@@ -30,11 +32,15 @@ from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
     _get_model_file,
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
     deprecate,
     is_accelerate_available,
     is_omegaconf_available,
+    is_peft_available,
     is_transformers_available,
     logging,
+    recurse_remove_peft_layers,
 )
 from .utils.import_utils import BACKENDS_MAPPING
@@ -61,6 +67,21 @@ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
 CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
 
+# Below should be `True` if the current versions of `peft` and `transformers` are compatible with
+# the PEFT backend. The PEFT backend is used automatically when compatible versions of both
+# libraries are available.
+# For PEFT the version has to be greater than 0.5 and for transformers it has to be greater than 4.33.
+_required_peft_version = is_peft_available() and version.parse(
+    version.parse(importlib.metadata.version("peft")).base_version
+) > version.parse("0.5")
+_required_transformers_version = version.parse(
+    version.parse(importlib.metadata.version("transformers")).base_version
+) > version.parse("4.33")
+USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
+
+LORA_DEPRECATION_MESSAGE = "You are using an old version of the LoRA backend. This will be deprecated in the next releases in favor of PEFT. Make sure to install the latest PEFT and transformers packages in the future."
+
 
 class PatchedLoraProjection(nn.Module):
     def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
         super().__init__()
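
A quick aside on the double `version.parse` dance above, since it looks redundant: `base_version` maps a dev or pre-release build (e.g. an install from git, as the CI does) onto its plain release number before comparing. A small sketch with illustrative version strings:

from packaging import version

dev = version.parse("0.6.0.dev0")  # e.g. peft installed from git
print(dev >= version.parse("0.6.0"))  # False: pre-releases sort below the release
print(dev.base_version)               # "0.6.0"
print(version.parse(dev.base_version) >= version.parse("0.6.0"))  # True
# The actual gate compares against "0.5", so any 0.6.x build, including dev
# builds, enables the PEFT backend.
print(version.parse(dev.base_version) > version.parse("0.5"))     # True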
@@ -1077,6 +1098,7 @@ class LoraLoaderMixin:
     text_encoder_name = TEXT_ENCODER_NAME
     unet_name = UNET_NAME
     num_fused_loras = 0
+    use_peft_backend = USE_PEFT_BACKEND
 
     def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
         """
@@ -1268,6 +1290,7 @@ class LoraLoaderMixin:
         state_dict = pretrained_model_name_or_path_or_dict
         network_alphas = None
 
+        # TODO: replace it with a method from `state_dict_utils`
         if all(
             (
                 k.startswith("lora_te_")
@@ -1520,55 +1543,35 @@ class LoraLoaderMixin:
         if len(text_encoder_lora_state_dict) > 0:
             logger.info(f"Loading {prefix}.")
             rank = {}
+            text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
 
-            if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
-                # Convert from the old naming convention to the new naming convention.
-                #
-                # Previously, the old LoRA layers were stored on the state dict at the
-                # same level as the attention block i.e.
-                # `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`.
-                #
-                # This is no actual module at that point, they were monkey patched on to the
-                # existing module. We want to be able to load them via their actual state dict.
-                # They're in `PatchedLoraProjection.lora_linear_layer` now.
-                for name, _ in text_encoder_attn_modules(text_encoder):
-                    text_encoder_lora_state_dict[
-                        f"{name}.q_proj.lora_linear_layer.up.weight"
-                    ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight")
-                    text_encoder_lora_state_dict[
-                        f"{name}.k_proj.lora_linear_layer.up.weight"
-                    ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight")
-                    text_encoder_lora_state_dict[
-                        f"{name}.v_proj.lora_linear_layer.up.weight"
-                    ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight")
-                    text_encoder_lora_state_dict[
-                        f"{name}.out_proj.lora_linear_layer.up.weight"
-                    ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight")
-                    text_encoder_lora_state_dict[
-                        f"{name}.q_proj.lora_linear_layer.down.weight"
-                    ] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight")
-                    text_encoder_lora_state_dict[
-                        f"{name}.k_proj.lora_linear_layer.down.weight"
-                    ] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight")
-                    text_encoder_lora_state_dict[
-                        f"{name}.v_proj.lora_linear_layer.down.weight"
-                    ] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight")
-                    text_encoder_lora_state_dict[
-                        f"{name}.out_proj.lora_linear_layer.down.weight"
-                    ] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
-
-            for name, _ in text_encoder_attn_modules(text_encoder):
-                rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
-                rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
-
-            patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
-            if patch_mlp:
-                for name, _ in text_encoder_mlp_modules(text_encoder):
-                    rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
-                    rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
-                    rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
-                    rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
+            if cls.use_peft_backend:
+                # convert state dict
+                text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
+
+                for name, _ in text_encoder_attn_modules(text_encoder):
+                    rank_key = f"{name}.out_proj.lora_B.weight"
+                    rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
+
+                patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
+                if patch_mlp:
+                    for name, _ in text_encoder_mlp_modules(text_encoder):
+                        rank_key_fc1 = f"{name}.fc1.lora_B.weight"
+                        rank_key_fc2 = f"{name}.fc2.lora_B.weight"
+                        rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
+                        rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
+            else:
+                for name, _ in text_encoder_attn_modules(text_encoder):
+                    rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
+                    rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
+
+                patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
+                if patch_mlp:
+                    for name, _ in text_encoder_mlp_modules(text_encoder):
+                        rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
+                        rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
+                        rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
+                        rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
 
         if network_alphas is not None:
             alpha_keys = [
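
The rank bookkeeping above reads the adapter rank directly off the up-projection weight. A hedged sketch of why `shape[1]` is the rank (tensor names are illustrative, not the diffusers API):

import torch

in_features, out_features, r = 768, 768, 4
lora_A = torch.nn.Linear(in_features, r, bias=False)   # down-projection, weight (r, in)
lora_B = torch.nn.Linear(r, out_features, bias=False)  # up-projection, weight (out, r)

state_dict = {
    "text_model.encoder.layers.0.self_attn.out_proj.lora_B.weight": lora_B.weight,
}
rank = {key: weight.shape[1] for key, weight in state_dict.items()}
print(rank)  # {...: 4} — the second dim of lora_B is the rank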
@@ -1578,56 +1581,79 @@ class LoraLoaderMixin:
                 k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
             }
 
-            cls._modify_text_encoder(
-                text_encoder,
-                lora_scale,
-                network_alphas,
-                rank=rank,
-                patch_mlp=patch_mlp,
-                low_cpu_mem_usage=low_cpu_mem_usage,
-            )
-
-            is_pipeline_offloaded = _pipeline is not None and any(
-                isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
-            )
-            if is_pipeline_offloaded and low_cpu_mem_usage:
-                low_cpu_mem_usage = True
-                logger.info(
-                    f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
-                )
-
-            if low_cpu_mem_usage:
-                device = next(iter(text_encoder_lora_state_dict.values())).device
-                dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
-                unexpected_keys = load_model_dict_into_meta(
-                    text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
-                )
-            else:
-                load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
-                unexpected_keys = load_state_dict_results.unexpected_keys
-
-            if len(unexpected_keys) != 0:
-                raise ValueError(
-                    f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
-                )
-
-            # <Unsafe code
-            # We can be sure that the following works as all we do is change the dtype and device of the text encoder
-            # Now we remove any existing hooks to
-            is_model_cpu_offload = False
-            is_sequential_cpu_offload = False
-            if _pipeline is not None:
-                for _, component in _pipeline.components.items():
-                    if isinstance(component, torch.nn.Module):
-                        if hasattr(component, "_hf_hook"):
-                            is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                            is_sequential_cpu_offload = isinstance(
-                                getattr(component, "_hf_hook"), AlignDevicesHook
-                            )
-                            logger.info(
-                                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                            )
-                            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+            if cls.use_peft_backend:
+                from peft import LoraConfig
+
+                lora_rank = list(rank.values())[0]
+                # By definition, the scale should be alpha divided by rank.
+                # https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
+                alpha = lora_scale * lora_rank
+
+                target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
+                if patch_mlp:
+                    target_modules += ["fc1", "fc2"]
+
+                # TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
+                lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
+
+                text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
+
+                is_model_cpu_offload = False
+                is_sequential_cpu_offload = False
+            else:
+                cls._modify_text_encoder(
+                    text_encoder,
+                    lora_scale,
+                    network_alphas,
+                    rank=rank,
+                    patch_mlp=patch_mlp,
+                    low_cpu_mem_usage=low_cpu_mem_usage,
+                )
+
+                is_pipeline_offloaded = _pipeline is not None and any(
+                    isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
+                    for c in _pipeline.components.values()
+                )
+                if is_pipeline_offloaded and low_cpu_mem_usage:
+                    low_cpu_mem_usage = True
+                    logger.info(
+                        f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
+                    )
+
+                if low_cpu_mem_usage:
+                    device = next(iter(text_encoder_lora_state_dict.values())).device
+                    dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
+                    unexpected_keys = load_model_dict_into_meta(
+                        text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
+                    )
+                else:
+                    load_state_dict_results = text_encoder.load_state_dict(
+                        text_encoder_lora_state_dict, strict=False
+                    )
+                    unexpected_keys = load_state_dict_results.unexpected_keys
+
+                if len(unexpected_keys) != 0:
+                    raise ValueError(
+                        f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
+                    )
+
+                # <Unsafe code
+                # We can be sure that the following works as all we do is change the dtype and device of the text encoder
+                # Now we remove any existing hooks to
+                is_model_cpu_offload = False
+                is_sequential_cpu_offload = False
+                if _pipeline is not None:
+                    for _, component in _pipeline.components.items():
+                        if isinstance(component, torch.nn.Module):
+                            if hasattr(component, "_hf_hook"):
+                                is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                                is_sequential_cpu_offload = isinstance(
+                                    getattr(component, "_hf_hook"), AlignDevicesHook
+                                )
+                                logger.info(
+                                    "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                                )
+                                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
 
         text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
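
The `alpha = lora_scale * lora_rank` line deserves a second look: PEFT applies `scaling = lora_alpha / r` to the LoRA delta, so choosing alpha this way makes the effective multiplier come out to exactly `lora_scale`. A worked check with illustrative numbers:

lora_rank = 4
lora_scale = 0.8

alpha = lora_scale * lora_rank  # 3.2, passed as lora_alpha to LoraConfig
scaling = alpha / lora_rank     # what PEFT computes internally
assert scaling == lora_scale    # the user-facing scale is preserved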
@@ -1645,10 +1671,27 @@ class LoraLoaderMixin:
         return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
 
     def _remove_text_encoder_monkey_patch(self):
-        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+        if self.use_peft_backend:
+            remove_method = recurse_remove_peft_layers
+        else:
+            remove_method = self._remove_text_encoder_monkey_patch_classmethod
+
+        if hasattr(self, "text_encoder"):
+            remove_method(self.text_encoder)
+            if self.use_peft_backend:
+                del self.text_encoder.peft_config
+                self.text_encoder._hf_peft_config_loaded = None
+        if hasattr(self, "text_encoder_2"):
+            remove_method(self.text_encoder_2)
+            if self.use_peft_backend:
+                del self.text_encoder_2.peft_config
+                self.text_encoder_2._hf_peft_config_loaded = None
 
     @classmethod
     def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
+        deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE)
+
         for _, attn_module in text_encoder_attn_modules(text_encoder):
             if isinstance(attn_module.q_proj, PatchedLoraProjection):
                 attn_module.q_proj.lora_linear_layer = None
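
For context, `recurse_remove_peft_layers` restores the text encoder to its pre-LoRA module tree. A conceptual sketch of the idea only; this is not the diffusers implementation, and `base_layer` is an assumed attribute name used for illustration:

import torch.nn as nn

def strip_adapter_layers(module: nn.Module) -> None:
    # Walk the module tree; wherever an adapter layer wraps an original
    # module, put the original module back in its place.
    for name, child in module.named_children():
        base = getattr(child, "base_layer", None)  # assumed wrapper attribute
        if base is not None:
            setattr(module, name, base)
        else:
            strip_adapter_layers(child)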
@@ -1675,6 +1718,7 @@ class LoraLoaderMixin:
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
         """
+        deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE)
 
         def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
             linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
@@ -2049,24 +2093,38 @@ class LoraLoaderMixin:
         if fuse_unet:
             self.unet.fuse_lora(lora_scale)
 
-        def fuse_text_encoder_lora(text_encoder):
-            for _, attn_module in text_encoder_attn_modules(text_encoder):
-                if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                    attn_module.q_proj._fuse_lora(lora_scale)
-                    attn_module.k_proj._fuse_lora(lora_scale)
-                    attn_module.v_proj._fuse_lora(lora_scale)
-                    attn_module.out_proj._fuse_lora(lora_scale)
-
-            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-                if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                    mlp_module.fc1._fuse_lora(lora_scale)
-                    mlp_module.fc2._fuse_lora(lora_scale)
+        if self.use_peft_backend:
+            from peft.tuners.tuners_utils import BaseTunerLayer
+
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+                for module in text_encoder.modules():
+                    if isinstance(module, BaseTunerLayer):
+                        if lora_scale != 1.0:
+                            module.scale_layer(lora_scale)
+
+                        module.merge()
+        else:
+            deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
+
+            def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
+                for _, attn_module in text_encoder_attn_modules(text_encoder):
+                    if isinstance(attn_module.q_proj, PatchedLoraProjection):
+                        attn_module.q_proj._fuse_lora(lora_scale)
+                        attn_module.k_proj._fuse_lora(lora_scale)
+                        attn_module.v_proj._fuse_lora(lora_scale)
+                        attn_module.out_proj._fuse_lora(lora_scale)
+
+                for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+                    if isinstance(mlp_module.fc1, PatchedLoraProjection):
+                        mlp_module.fc1._fuse_lora(lora_scale)
+                        mlp_module.fc2._fuse_lora(lora_scale)
 
         if fuse_text_encoder:
             if hasattr(self, "text_encoder"):
-                fuse_text_encoder_lora(self.text_encoder)
+                fuse_text_encoder_lora(self.text_encoder, lora_scale)
             if hasattr(self, "text_encoder_2"):
-                fuse_text_encoder_lora(self.text_encoder_2)
+                fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
 
     def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
         r"""
@@ -2088,18 +2146,29 @@ class LoraLoaderMixin:
         if unfuse_unet:
             self.unet.unfuse_lora()
 
-        def unfuse_text_encoder_lora(text_encoder):
-            for _, attn_module in text_encoder_attn_modules(text_encoder):
-                if isinstance(attn_module.q_proj, PatchedLoraProjection):
-                    attn_module.q_proj._unfuse_lora()
-                    attn_module.k_proj._unfuse_lora()
-                    attn_module.v_proj._unfuse_lora()
-                    attn_module.out_proj._unfuse_lora()
-
-            for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-                if isinstance(mlp_module.fc1, PatchedLoraProjection):
-                    mlp_module.fc1._unfuse_lora()
-                    mlp_module.fc2._unfuse_lora()
+        if self.use_peft_backend:
+            from peft.tuners.tuners_utils import BaseTunerLayer
+
+            def unfuse_text_encoder_lora(text_encoder):
+                for module in text_encoder.modules():
+                    if isinstance(module, BaseTunerLayer):
+                        module.unmerge()
+        else:
+            deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
+
+            def unfuse_text_encoder_lora(text_encoder):
+                for _, attn_module in text_encoder_attn_modules(text_encoder):
+                    if isinstance(attn_module.q_proj, PatchedLoraProjection):
+                        attn_module.q_proj._unfuse_lora()
+                        attn_module.k_proj._unfuse_lora()
+                        attn_module.v_proj._unfuse_lora()
+                        attn_module.out_proj._unfuse_lora()
+
+                for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+                    if isinstance(mlp_module.fc1, PatchedLoraProjection):
+                        mlp_module.fc1._unfuse_lora()
+                        mlp_module.fc2._unfuse_lora()
 
         if unfuse_text_encoder:
             if hasattr(self, "text_encoder"):
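
Both branches implement the same algebra. For a LoRA-adapted linear layer, merging folds the low-rank delta into the base weight and unmerging subtracts it back out. A sketch of the underlying math, with shapes following the LoRA convention and illustrative values:

import torch

out_f, in_f, r = 16, 16, 4
W = torch.randn(out_f, in_f)  # frozen base weight
A = torch.randn(r, in_f)      # lora_A
B = torch.randn(out_f, r)     # lora_B
scaling = 1.0                 # lora_alpha / r

W_fused = W + scaling * (B @ A)           # merge(): one plain matmul at inference
W_restored = W_fused - scaling * (B @ A)  # unmerge(): recover the base weight
assert torch.allclose(W, W_restored, atol=1e-5)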
@@ -2810,5 +2879,16 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
         )
 
     def _remove_text_encoder_monkey_patch(self):
-        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
-        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
+        if self.use_peft_backend:
+            recurse_remove_peft_layers(self.text_encoder)
+            # TODO: @younesbelkada handle this in transformers side
+            del self.text_encoder.peft_config
+            self.text_encoder._hf_peft_config_loaded = None
+
+            recurse_remove_peft_layers(self.text_encoder_2)
+            del self.text_encoder_2.peft_config
+            self.text_encoder_2._hf_peft_config_loaded = None
+        else:
+            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+            self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
@@ -25,18 +25,25 @@ from ..utils import logging
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
-    for _, attn_module in text_encoder_attn_modules(text_encoder):
-        if isinstance(attn_module.q_proj, PatchedLoraProjection):
-            attn_module.q_proj.lora_scale = lora_scale
-            attn_module.k_proj.lora_scale = lora_scale
-            attn_module.v_proj.lora_scale = lora_scale
-            attn_module.out_proj.lora_scale = lora_scale
-
-    for _, mlp_module in text_encoder_mlp_modules(text_encoder):
-        if isinstance(mlp_module.fc1, PatchedLoraProjection):
-            mlp_module.fc1.lora_scale = lora_scale
-            mlp_module.fc2.lora_scale = lora_scale
+def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
+    if use_peft_backend:
+        from peft.tuners.lora import LoraLayer
+
+        for module in text_encoder.modules():
+            if isinstance(module, LoraLayer):
+                module.scaling[module.active_adapter] = lora_scale
+    else:
+        for _, attn_module in text_encoder_attn_modules(text_encoder):
+            if isinstance(attn_module.q_proj, PatchedLoraProjection):
+                attn_module.q_proj.lora_scale = lora_scale
+                attn_module.k_proj.lora_scale = lora_scale
+                attn_module.v_proj.lora_scale = lora_scale
+                attn_module.out_proj.lora_scale = lora_scale
+
+        for _, mlp_module in text_encoder_mlp_modules(text_encoder):
+            if isinstance(mlp_module.fc1, PatchedLoraProjection):
+                mlp_module.fc1.lora_scale = lora_scale
+                mlp_module.fc2.lora_scale = lora_scale
 
 
 class LoRALinearLayer(nn.Module):
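
Downstream, every pipeline's prompt encoding now threads its backend flag into this helper (the hunks below), which is what makes the per-call LoRA scale work under both backends. A hedged usage sketch with placeholder ids:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_lora_weights("path/to/lora")  # placeholder

# "scale" reaches adjust_lora_scale_text_encoder(..., self.use_peft_backend)
# for the text encoder, and the attention processors for the UNet.
image = pipe("an astronaut riding a horse", cross_attention_kwargs={"scale": 0.5}).images[0]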
@@ -303,7 +303,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -301,7 +301,7 @@ class AltDiffusionImg2ImgPipeline(
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -291,7 +291,7 @@ class StableDiffusionControlNetPipeline(
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -315,7 +315,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
             self._lora_scale = lora_scale
 
            # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -442,7 +442,7 @@ class StableDiffusionControlNetInpaintPipeline(
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -315,8 +315,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -288,8 +288,8 @@ class StableDiffusionXLControlNetPipeline(
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -326,8 +326,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
 
         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -308,7 +308,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -301,7 +301,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -332,7 +332,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -213,7 +213,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -481,7 +481,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -278,7 +278,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -309,7 +309,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -302,7 +302,7 @@ class StableDiffusionImg2ImgPipeline(
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -375,7 +375,7 @@ class StableDiffusionInpaintPipeline(
             self._lora_scale = lora_scale
 
             # dynamically adjust the LoRA scale
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
 
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1