Unverified Commit 493f9529 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`PEFT` / `LoRA`] PEFT integration - text encoder (#5058)



* more fixes

* up

* up

* style

* add in setup

* oops

* more changes

* v1 rzfactor CI

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* few todos

* protect torch import

* style

* fix fuse text encoder

* Update src/diffusers/loaders.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* replace with `recurse_replace_peft_layers`

* keep old modules for BC

* adjustments on `adjust_lora_scale_text_encoder`

* nit

* move tests

* add conversion utils

* remove unneeded methods

* use class method instead

* oops

* use `base_version`

* fix examples

* fix CI

* fix weird error with python 3.8

* fix

* better fix

* style

* Apply suggestions from code review
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* add comment

* Apply suggestions from code review
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* conv2d support for recurse remove

* added docstrings

* more docstring

* add deprecate

* revert

* try to fix merge conflicts

* v1 tests

* add new decorator

* add saving utilities test

* adapt tests a bit

* add save / from_pretrained tests

* add saving tests

* add scale tests

* fix deps tests

* fix lora CI

* fix tests

* add comment

* fix

* style

* add slow tests

* slow tests pass

* style

* Update src/diffusers/utils/import_utils.py
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* circumvents pattern finding issue

* left a todo

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* update hub path

* add lora workflow

* fix

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>
parent b32555a2
name: Fast tests for PRs - PEFT backend
on:
pull_request:
branches:
- main
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60
jobs:
run_fast_tests:
strategy:
fail-fast: false
matrix:
config:
- name: LoRA
framework: lora
runner: docker-cpu
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_lora
name: ${{ matrix.config.name }}
runs-on: ${{ matrix.config.runner }}
container:
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
run:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
run: |
apt-get update && apt-get install libsndfile1-dev libgl1 -y
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate.git
python -m pip install -U git+https://github.com/huggingface/transformers.git
python -m pip install -U git+https://github.com/huggingface/peft.git
- name: Environment
run: |
python utils/print_env.py
- name: Run fast PyTorch LoRA CPU tests with PEFT backend
if: ${{ matrix.config.framework == 'lora' }}
run: |
python -m pytest -n 2 --max-worker-restart=0 --dist=loadfile \
-s -v \
--make-reports=tests_${{ matrix.config.report }} \
tests/lora/test_lora_layers_peft.py
\ No newline at end of file
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import re
from collections import defaultdict
......@@ -23,6 +24,7 @@ import requests
import safetensors
import torch
from huggingface_hub import hf_hub_download, model_info
from packaging import version
from torch import nn
from .models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
......@@ -30,11 +32,15 @@ from .utils import (
DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
_get_model_file,
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
deprecate,
is_accelerate_available,
is_omegaconf_available,
is_peft_available,
is_transformers_available,
logging,
recurse_remove_peft_layers,
)
from .utils.import_utils import BACKENDS_MAPPING
......@@ -61,6 +67,21 @@ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
# available.
# For PEFT it is has to be greater than 0.6.0 and for transformers it has to be greater than 4.33.1.
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse("0.5")
_required_transformers_version = version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse("4.33")
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
LORA_DEPRECATION_MESSAGE = "You are using an old version of LoRA backend. This will be deprecated in the next releases in favor of PEFT make sure to install the latest PEFT and transformers packages in the future."
class PatchedLoraProjection(nn.Module):
def __init__(self, regular_linear_layer, lora_scale=1, network_alpha=None, rank=4, dtype=None):
super().__init__()
......@@ -1077,6 +1098,7 @@ class LoraLoaderMixin:
text_encoder_name = TEXT_ENCODER_NAME
unet_name = UNET_NAME
num_fused_loras = 0
use_peft_backend = USE_PEFT_BACKEND
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
"""
......@@ -1268,6 +1290,7 @@ class LoraLoaderMixin:
state_dict = pretrained_model_name_or_path_or_dict
network_alphas = None
# TODO: replace it with a method from `state_dict_utils`
if all(
(
k.startswith("lora_te_")
......@@ -1520,55 +1543,35 @@ class LoraLoaderMixin:
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {prefix}.")
rank = {}
text_encoder_lora_state_dict = convert_state_dict_to_diffusers(text_encoder_lora_state_dict)
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
# Convert from the old naming convention to the new naming convention.
#
# Previously, the old LoRA layers were stored on the state dict at the
# same level as the attention block i.e.
# `text_model.encoder.layers.11.self_attn.to_out_lora.up.weight`.
#
# This is no actual module at that point, they were monkey patched on to the
# existing module. We want to be able to load them via their actual state dict.
# They're in `PatchedLoraProjection.lora_linear_layer` now.
if cls.use_peft_backend:
# convert state dict
text_encoder_lora_state_dict = convert_state_dict_to_peft(text_encoder_lora_state_dict)
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_B.weight"
rank[rank_key] = text_encoder_lora_state_dict[rank_key].shape[1]
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_B.weight"
rank_key_fc2 = f"{name}.fc2.lora_B.weight"
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
else:
for name, _ in text_encoder_attn_modules(text_encoder):
text_encoder_lora_state_dict[
f"{name}.q_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.k_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.v_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.out_proj.lora_linear_layer.up.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.up.weight")
text_encoder_lora_state_dict[
f"{name}.q_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_q_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.k_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_k_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.v_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_v_lora.down.weight")
text_encoder_lora_state_dict[
f"{name}.out_proj.lora_linear_layer.down.weight"
] = text_encoder_lora_state_dict.pop(f"{name}.to_out_lora.down.weight")
for name, _ in text_encoder_attn_modules(text_encoder):
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
rank.update({rank_key_fc1: text_encoder_lora_state_dict[rank_key_fc1].shape[1]})
rank.update({rank_key_fc2: text_encoder_lora_state_dict[rank_key_fc2].shape[1]})
rank_key = f"{name}.out_proj.lora_linear_layer.up.weight"
rank.update({rank_key: text_encoder_lora_state_dict[rank_key].shape[1]})
patch_mlp = any(".mlp." in key for key in text_encoder_lora_state_dict.keys())
if patch_mlp:
for name, _ in text_encoder_mlp_modules(text_encoder):
rank_key_fc1 = f"{name}.fc1.lora_linear_layer.up.weight"
rank_key_fc2 = f"{name}.fc2.lora_linear_layer.up.weight"
rank[rank_key_fc1] = text_encoder_lora_state_dict[rank_key_fc1].shape[1]
rank[rank_key_fc2] = text_encoder_lora_state_dict[rank_key_fc2].shape[1]
if network_alphas is not None:
alpha_keys = [
......@@ -1578,56 +1581,79 @@ class LoraLoaderMixin:
k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
}
cls._modify_text_encoder(
text_encoder,
lora_scale,
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if cls.use_peft_backend:
from peft import LoraConfig
is_pipeline_offloaded = _pipeline is not None and any(
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
)
if is_pipeline_offloaded and low_cpu_mem_usage:
low_cpu_mem_usage = True
logger.info(
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
)
lora_rank = list(rank.values())[0]
# By definition, the scale should be alpha divided by rank.
# https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/tuners/lora/layer.py#L71
alpha = lora_scale * lora_rank
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
unexpected_keys = load_model_dict_into_meta(
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
)
target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
if patch_mlp:
target_modules += ["fc1", "fc2"]
# TODO: support multi alpha / rank: https://github.com/huggingface/peft/pull/873
lora_config = LoraConfig(r=lora_rank, target_modules=target_modules, lora_alpha=alpha)
text_encoder.load_adapter(adapter_state_dict=text_encoder_lora_state_dict, peft_config=lora_config)
is_model_cpu_offload = False
is_sequential_cpu_offload = False
else:
load_state_dict_results = text_encoder.load_state_dict(text_encoder_lora_state_dict, strict=False)
unexpected_keys = load_state_dict_results.unexpected_keys
cls._modify_text_encoder(
text_encoder,
lora_scale,
network_alphas,
rank=rank,
patch_mlp=patch_mlp,
low_cpu_mem_usage=low_cpu_mem_usage,
)
if len(unexpected_keys) != 0:
raise ValueError(
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
is_pipeline_offloaded = _pipeline is not None and any(
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook")
for c in _pipeline.components.values()
)
if is_pipeline_offloaded and low_cpu_mem_usage:
low_cpu_mem_usage = True
logger.info(
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
)
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
unexpected_keys = load_model_dict_into_meta(
text_encoder, text_encoder_lora_state_dict, device=device, dtype=dtype
)
else:
load_state_dict_results = text_encoder.load_state_dict(
text_encoder_lora_state_dict, strict=False
)
unexpected_keys = load_state_dict_results.unexpected_keys
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
if len(unexpected_keys) != 0:
raise ValueError(
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
......@@ -1645,10 +1671,27 @@ class LoraLoaderMixin:
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
def _remove_text_encoder_monkey_patch(self):
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
if self.use_peft_backend:
remove_method = recurse_remove_peft_layers
else:
remove_method = self._remove_text_encoder_monkey_patch_classmethod
if hasattr(self, "text_encoder"):
remove_method(self.text_encoder)
if self.use_peft_backend:
del self.text_encoder.peft_config
self.text_encoder._hf_peft_config_loaded = None
if hasattr(self, "text_encoder_2"):
remove_method(self.text_encoder_2)
if self.use_peft_backend:
del self.text_encoder_2.peft_config
self.text_encoder_2._hf_peft_config_loaded = None
@classmethod
def _remove_text_encoder_monkey_patch_classmethod(cls, text_encoder):
deprecate("_remove_text_encoder_monkey_patch_classmethod", "0.23", LORA_DEPRECATION_MESSAGE)
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_linear_layer = None
......@@ -1675,6 +1718,7 @@ class LoraLoaderMixin:
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
"""
deprecate("_modify_text_encoder", "0.23", LORA_DEPRECATION_MESSAGE)
def create_patched_linear_lora(model, network_alpha, rank, dtype, lora_parameters):
linear_layer = model.regular_linear_layer if isinstance(model, PatchedLoraProjection) else model
......@@ -2049,24 +2093,38 @@ class LoraLoaderMixin:
if fuse_unet:
self.unet.fuse_lora(lora_scale)
def fuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)
if self.use_peft_backend:
from peft.tuners.tuners_utils import BaseTunerLayer
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
if lora_scale != 1.0:
module.scale_layer(lora_scale)
module.merge()
else:
deprecate("fuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._fuse_lora(lora_scale)
attn_module.k_proj._fuse_lora(lora_scale)
attn_module.v_proj._fuse_lora(lora_scale)
attn_module.out_proj._fuse_lora(lora_scale)
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._fuse_lora(lora_scale)
mlp_module.fc2._fuse_lora(lora_scale)
if fuse_text_encoder:
if hasattr(self, "text_encoder"):
fuse_text_encoder_lora(self.text_encoder)
fuse_text_encoder_lora(self.text_encoder, lora_scale)
if hasattr(self, "text_encoder_2"):
fuse_text_encoder_lora(self.text_encoder_2)
fuse_text_encoder_lora(self.text_encoder_2, lora_scale)
def unfuse_lora(self, unfuse_unet: bool = True, unfuse_text_encoder: bool = True):
r"""
......@@ -2088,18 +2146,29 @@ class LoraLoaderMixin:
if unfuse_unet:
self.unet.unfuse_lora()
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._unfuse_lora()
attn_module.k_proj._unfuse_lora()
attn_module.v_proj._unfuse_lora()
attn_module.out_proj._unfuse_lora()
if self.use_peft_backend:
from peft.tuners.tuner_utils import BaseTunerLayer
def unfuse_text_encoder_lora(text_encoder):
for module in text_encoder.modules():
if isinstance(module, BaseTunerLayer):
module.unmerge()
else:
deprecate("unfuse_text_encoder_lora", "0.23", LORA_DEPRECATION_MESSAGE)
def unfuse_text_encoder_lora(text_encoder):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj._unfuse_lora()
attn_module.k_proj._unfuse_lora()
attn_module.v_proj._unfuse_lora()
attn_module.out_proj._unfuse_lora()
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._unfuse_lora()
mlp_module.fc2._unfuse_lora()
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1._unfuse_lora()
mlp_module.fc2._unfuse_lora()
if unfuse_text_encoder:
if hasattr(self, "text_encoder"):
......@@ -2810,5 +2879,16 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
)
def _remove_text_encoder_monkey_patch(self):
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
if self.use_peft_backend:
recurse_remove_peft_layers(self.text_encoder)
# TODO: @younesbelkada handle this in transformers side
del self.text_encoder.peft_config
self.text_encoder._hf_peft_config_loaded = None
recurse_remove_peft_layers(self.text_encoder_2)
del self.text_encoder_2.peft_config
self.text_encoder_2._hf_peft_config_loaded = None
else:
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
......@@ -25,18 +25,25 @@ from ..utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0):
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_scale = lora_scale
attn_module.k_proj.lora_scale = lora_scale
attn_module.v_proj.lora_scale = lora_scale
attn_module.out_proj.lora_scale = lora_scale
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1.lora_scale = lora_scale
mlp_module.fc2.lora_scale = lora_scale
def adjust_lora_scale_text_encoder(text_encoder, lora_scale: float = 1.0, use_peft_backend: bool = False):
if use_peft_backend:
from peft.tuners.lora import LoraLayer
for module in text_encoder.modules():
if isinstance(module, LoraLayer):
module.scaling[module.active_adapter] = lora_scale
else:
for _, attn_module in text_encoder_attn_modules(text_encoder):
if isinstance(attn_module.q_proj, PatchedLoraProjection):
attn_module.q_proj.lora_scale = lora_scale
attn_module.k_proj.lora_scale = lora_scale
attn_module.v_proj.lora_scale = lora_scale
attn_module.out_proj.lora_scale = lora_scale
for _, mlp_module in text_encoder_mlp_modules(text_encoder):
if isinstance(mlp_module.fc1, PatchedLoraProjection):
mlp_module.fc1.lora_scale = lora_scale
mlp_module.fc2.lora_scale = lora_scale
class LoRALinearLayer(nn.Module):
......
......@@ -303,7 +303,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -301,7 +301,7 @@ class AltDiffusionImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -291,7 +291,7 @@ class StableDiffusionControlNetPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -315,7 +315,7 @@ class StableDiffusionControlNetImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -442,7 +442,7 @@ class StableDiffusionControlNetInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -315,8 +315,8 @@ class StableDiffusionXLControlNetInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
prompt = [prompt] if isinstance(prompt, str) else prompt
......
......@@ -288,8 +288,8 @@ class StableDiffusionXLControlNetPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
prompt = [prompt] if isinstance(prompt, str) else prompt
......
......@@ -326,8 +326,8 @@ class StableDiffusionXLControlNetImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale, self.use_peft_backend)
prompt = [prompt] if isinstance(prompt, str) else prompt
......
......@@ -308,7 +308,7 @@ class CycleDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -301,7 +301,7 @@ class StableDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lo
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -332,7 +332,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, TextualInversion
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -213,7 +213,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline, TextualInversionLoader
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -481,7 +481,7 @@ class StableDiffusionDiffEditPipeline(DiffusionPipeline, TextualInversionLoaderM
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -278,7 +278,7 @@ class StableDiffusionGLIGENPipeline(DiffusionPipeline):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -309,7 +309,7 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -302,7 +302,7 @@ class StableDiffusionImg2ImgPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
......@@ -375,7 +375,7 @@ class StableDiffusionInpaintPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale, self.use_peft_backend)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment