Unverified Commit 2bfa55f4 authored by Younes Belkada, committed by GitHub

[`core` / `PEFT` / `LoRA`] Integrate PEFT into Unet (#5151)



* v1

* add tests and fix previous failing tests

* fix CI

* add tests + v1 `PeftLayerScaler`

* style

* add scale retrieval mechanism

* fix CI

* up

* up

* simple approach --> not same results for some reason

* fix issues

* fix copies

* remove unneeded method

* active adapters!

* fix merge conflicts

* up

* up

* kohya - test-1

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix scale

* fix copies

* add comment

* multi adapters

* fix tests

* oops

* v1 faster loading - in progress

* Revert "v1 faster loading - in progress"

This reverts commit ac925f81321e95fc8168184c3346bf3d75404d5a.

* kohya same generation

* fix some slow tests

* peft integration features for unet lora (see the usage sketch below)

1. Support for Multiple ranks/alphas
2. Support for Multiple active adapters
3. Support for enabling/disabling LoRAs
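A minimal usage sketch of the API these three points describe. Hedged: the base checkpoint is a public SDXL model, the LoRA repo names are placeholders, and this assumes `peft`/`transformers` versions recent enough that `USE_PEFT_BACKEND` is `True`:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# 1. each LoRA may ship its own rank/alpha; load each under a distinct adapter name
pipe.load_lora_weights("user/lora-one", adapter_name="one")  # placeholder repo
pipe.load_lora_weights("user/lora-two", adapter_name="two")  # placeholder repo

# 2. activate several adapters at once, each weighted independently
pipe.set_adapters(["one", "two"], adapter_weights=[1.0, 0.5])

# 3. toggle every LoRA layer off, then back on, without unloading anything
pipe.disable_lora()
pipe.enable_lora()
```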

* fix `get_peft_kwargs`

* Update loaders.py

* add some tests

* add unfuse tests

* fix tests

* up

* add set adapter from sourab and tests

* fix multi adapter tests

* style & quality

* style

* remove comment

* fix `adapter_name` issues

* fix unet adapter name for sdxl

* fix enabling/disabling adapters

* fix fuse / unfuse unet

* nit

* fix

* up

* fix cpu offloading

* fix another slow test

* fix another offload test

* add more tests

* all slow tests pass

* style

* fix alpha pattern for unet and text encoder

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/attention.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* up

* up

* clarify comment

* comments

* change comment order

* change comment order

* style & quality

* Update tests/lora/test_lora_layers_peft.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix bugs and add tests

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* refactor

* suggestion

* add break statement

* add compile tests

* move slow tests to peft tests as I modified them

* quality

* refactor a bit

* style

* change import

* style

* fix CI

* refactor slow tests one last time

* style

* oops

* oops

* oops

* final tweak tests

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* comments

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* remove comments

* more comments

* try

* revert

* add `safe_merge` tests

* add comment

* style, comments and run tests in fp16

* add warnings

* fix doc test

* replace with `adapter_weights`

* add `get_active_adapters()`

* expose `get_list_adapters` method
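A short hedged sketch of these two introspection helpers, continuing the `pipe` object from the sketch above; the returned values are illustrative:

```python
# adapters currently active on the pipeline, e.g. ["one", "two"]
print(pipe.get_active_adapters())

# every adapter attached to each component, active or not,
# e.g. {"unet": ["one", "two"], "text_encoder": ["one", "two"]}
print(pipe.get_list_adapters())
```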

* better error message

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* style

* trigger slow lora tests

* fix tests

* maybe fix last test

* revert

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* move `MIN_PEFT_VERSION`

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* let's not use class variable

* fix few nits

* change a bit offloading logic

* check earlier

* rm unneeded block

* break long line

* return empty list

* change logic a bit and address comments

* add typehint

* remove parenthesis

* fix

* revert to fp16 in tests

* add to gpu

* revert to old test

* style

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* change indent

* Apply suggestions from code review

* Apply suggestions from code review

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent 9bc55e8b
@@ -31,7 +31,14 @@ from ...models.attention_processor import (
 )
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    PIL_INTERPOLATION,
+    USE_PEFT_BACKEND,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
@@ -283,7 +290,7 @@ class StableDiffusionXLAdapterPipeline(
         self._lora_scale = lora_scale

         # dynamically adjust the LoRA scale
-        if not self.use_peft_backend:
+        if not USE_PEFT_BACKEND:
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
             adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
         else:
@@ -425,7 +432,7 @@ class StableDiffusionXLAdapterPipeline(
                 bs_embed * num_images_per_prompt, -1
             )

-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
             unscale_lora_layers(self.text_encoder_2)
...
@@ -23,7 +23,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet3DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import TextToVideoSDPipelineOutput
@@ -224,7 +231,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
         self._lora_scale = lora_scale

         # dynamically adjust the LoRA scale
-        if not self.use_peft_backend:
+        if not USE_PEFT_BACKEND:
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
         else:
             scale_lora_layers(self.text_encoder, lora_scale)
@@ -352,7 +359,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
         negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
         negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
...
@@ -24,7 +24,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet3DConditionModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
+from ...utils import (
+    USE_PEFT_BACKEND,
+    deprecate,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import TextToVideoSDPipelineOutput
@@ -286,7 +293,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
         self._lora_scale = lora_scale

         # dynamically adjust the LoRA scale
-        if not self.use_peft_backend:
+        if not USE_PEFT_BACKEND:
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
         else:
             scale_lora_layers(self.text_encoder, lora_scale)
@@ -414,7 +421,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
         negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
         negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
...
@@ -18,7 +18,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
-from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.outputs import BaseOutput
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -426,7 +426,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         self._lora_scale = lora_scale

         # dynamically adjust the LoRA scale
-        if not self.use_peft_backend:
+        if not USE_PEFT_BACKEND:
             adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
         else:
             scale_lora_layers(self.text_encoder, lora_scale)
@@ -554,7 +554,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
         negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
         negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

-        if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
             # Retrieve the original scale by scaling back the LoRA layers
             unscale_lora_layers(self.text_encoder)
...
@@ -31,7 +31,7 @@ from ...models.embeddings import (
 )
 from ...models.transformer_2d import Transformer2DModel
 from ...models.unet_2d_condition import UNet2DConditionOutput
-from ...utils import is_torch_version, logging
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import apply_freeu
@@ -1211,6 +1211,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # 3. down
         lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)

         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
         is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
@@ -1310,6 +1313,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         sample = self.conv_act(sample)
         sample = self.conv_out(sample)

+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self)
+
         if not return_dict:
             return (sample,)
...
@@ -26,9 +26,11 @@ from .constants import (
     FLAX_WEIGHTS_NAME,
     HF_MODULES_CACHE,
     HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    MIN_PEFT_VERSION,
     ONNX_EXTERNAL_WEIGHTS_NAME,
     ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
+    USE_PEFT_BACKEND,
     WEIGHTS_NAME,
 )
 from .deprecation_utils import deprecate
@@ -86,6 +88,7 @@ from .loading_utils import load_image
 from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (
+    check_peft_version,
     get_adapter_name,
     get_peft_kwargs,
     recurse_remove_peft_layers,
@@ -95,7 +98,11 @@ from .peft_utils import (
     unscale_lora_layers,
 )
 from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
-from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft
+from .state_dict_utils import (
+    convert_state_dict_to_diffusers,
+    convert_state_dict_to_peft,
+    convert_unet_state_dict_to_peft,
+)


 logger = get_logger(__name__)
...
@@ -11,13 +11,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import os

 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
+from packaging import version
+
+from .import_utils import is_peft_available, is_transformers_available


 default_cache_path = HUGGINGFACE_HUB_CACHE

+MIN_PEFT_VERSION = "0.5.0"
+
 CONFIG_NAME = "config.json"
 WEIGHTS_NAME = "diffusion_pytorch_model.bin"
@@ -30,3 +36,16 @@ DIFFUSERS_CACHE = default_cache_path
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
+
+# Below is `True` when the installed versions of `peft` and `transformers` are compatible with
+# the PEFT backend. When both requirements are met, the PEFT backend is used automatically.
+# For `peft` the version has to be greater than 0.5.0 and for `transformers` it has to be
+# greater than 4.33.
+_required_peft_version = is_peft_available() and version.parse(
+    version.parse(importlib.metadata.version("peft")).base_version
+) > version.parse(MIN_PEFT_VERSION)
+_required_transformers_version = is_transformers_available() and version.parse(
+    version.parse(importlib.metadata.version("transformers")).base_version
+) > version.parse("4.33")
+USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
@@ -15,8 +15,11 @@
 PEFT utilities: Utilities related to peft library
 """
 import collections
+import importlib

-from .import_utils import is_torch_available
+from packaging import version
+
+from .import_utils import is_peft_available, is_torch_available


 def recurse_remove_peft_layers(model):
@@ -53,7 +56,6 @@ def recurse_remove_peft_layers(model):
                 module.padding,
                 module.dilation,
                 module.groups,
-                module.bias,
             ).to(module.weight.device)

             new_module.weight = module.weight
@@ -106,10 +108,11 @@ def unscale_lora_layers(model):
             module.unscale_layer()


-def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
+def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
     rank_pattern = {}
     alpha_pattern = {}
     r = lora_alpha = list(rank_dict.values())[0]
+
     if len(set(rank_dict.values())) > 1:
         # get the rank occurring the most number of times
         r = collections.Counter(rank_dict.values()).most_common()[0][0]
@@ -118,13 +121,22 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
         rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
         rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}

-    if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
-        # get the alpha occurring the most number of times
-        lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
-
-        # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
-        alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
-        alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
+    if network_alpha_dict is not None:
+        if len(set(network_alpha_dict.values())) > 1:
+            # get the alpha occurring the most number of times
+            lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
+
+            # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
+            alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
+            if is_unet:
+                alpha_pattern = {
+                    ".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v
+                    for k, v in alpha_pattern.items()
+                }
+            else:
+                alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
+        else:
+            lora_alpha = set(network_alpha_dict.values()).pop()

     # layer names without the Diffusers specific
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
@@ -155,9 +167,9 @@ def set_adapter_layers(model, enabled=True):
         if isinstance(module, BaseTunerLayer):
             # The recent version of PEFT needs to call `enable_adapters` instead
             if hasattr(module, "enable_adapters"):
-                module.enable_adapters(enabled=False)
+                module.enable_adapters(enabled=enabled)
             else:
-                module.disable_adapters = True
+                module.disable_adapters = not enabled
@@ -182,3 +194,23 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
             module.set_adapter(adapter_names)
         else:
             module.active_adapter = adapter_names
+
+
+def check_peft_version(min_version: str) -> None:
+    r"""
+    Checks if the version of PEFT is compatible.
+
+    Args:
+        min_version (`str`):
+            The minimum version of PEFT to check against.
+    """
+    if not is_peft_available():
+        raise ValueError("PEFT is not installed. Please install it with `pip install peft`")
+
+    is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) > version.parse(min_version)
+
+    if not is_peft_version_compatible:
+        raise ValueError(
+            f"The version of PEFT you are using is not compatible, please use a version that is greater"
+            f" than {min_version}"
+        )
@@ -28,6 +28,22 @@ class StateDictType(enum.Enum):
     DIFFUSERS = "diffusers"


+# We need to define a proper mapping for the Unet since it uses different output keys than the
+# text encoder, e.g. to_q_lora -> q_proj / to_q
+UNET_TO_DIFFUSERS = {
+    ".to_out_lora.up": ".to_out.0.lora_B",
+    ".to_out_lora.down": ".to_out.0.lora_A",
+    ".to_q_lora.down": ".to_q.lora_A",
+    ".to_q_lora.up": ".to_q.lora_B",
+    ".to_k_lora.down": ".to_k.lora_A",
+    ".to_k_lora.up": ".to_k.lora_B",
+    ".to_v_lora.down": ".to_v.lora_A",
+    ".to_v_lora.up": ".to_v.lora_B",
+    ".lora.up": ".lora_B",
+    ".lora.down": ".lora_A",
+}
+
 DIFFUSERS_TO_PEFT = {
     ".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
     ".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
@@ -50,6 +66,8 @@ DIFFUSERS_OLD_TO_PEFT = {
     ".to_v_lora.down": ".v_proj.lora_A",
     ".to_out_lora.up": ".out_proj.lora_B",
     ".to_out_lora.down": ".out_proj.lora_A",
+    ".lora_linear_layer.up": ".lora_B",
+    ".lora_linear_layer.down": ".lora_A",
 }

 PEFT_TO_DIFFUSERS = {
@@ -84,6 +102,10 @@ DIFFUSERS_STATE_DICT_MAPPINGS = {
     StateDictType.PEFT: PEFT_TO_DIFFUSERS,
 }

+KEYS_TO_ALWAYS_REPLACE = {
+    ".processor.": ".",
+}
+

 def convert_state_dict(state_dict, mapping):
     r"""
@@ -103,6 +125,12 @@ def convert_state_dict(state_dict, mapping):
     """
     converted_state_dict = {}
     for k, v in state_dict.items():
+        # First, filter out the keys that we always want to replace
+        for pattern in KEYS_TO_ALWAYS_REPLACE.keys():
+            if pattern in k:
+                new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern]
+                k = k.replace(pattern, new_pattern)
+
         for pattern in mapping.keys():
             if pattern in k:
                 new_pattern = mapping[pattern]
@@ -184,3 +212,11 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
     mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
     return convert_state_dict(state_dict, mapping)
+
+
+def convert_unet_state_dict_to_peft(state_dict):
+    r"""
+    Converts a UNet LoRA state dict to the PEFT format, i.e. to `lora_A`/`lora_B` keys.
+    """
+    mapping = UNET_TO_DIFFUSERS
+    return convert_state_dict(state_dict, mapping)
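For illustration, a minimal sketch of what this converter does to one representative attention-processor key (the tensor value is elided as a string):

```python
from diffusers.utils import convert_unet_state_dict_to_peft

state_dict = {"mid_block.attentions.0.processor.to_q_lora.down.weight": "..."}
# `KEYS_TO_ALWAYS_REPLACE` first drops ".processor.", then the UNET_TO_DIFFUSERS
# mapping rewrites ".to_q_lora.down" to ".to_q.lora_A":
print(convert_unet_state_dict_to_peft(state_dict))
# {'mid_block.attentions.0.to_q.lora_A.weight': '...'}
```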
@@ -673,6 +673,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
         self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))


+@deprecate_after_peft_backend
 class SDXInpaintLoraMixinTests(unittest.TestCase):
     def get_dummy_inputs(self, device, seed=0, img_res=64, output_pil=True):
         # TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
@@ -1387,6 +1388,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
         ), "The pipeline was serialized with LoRA parameters fused inside of the respected modules. The loaded pipeline should yield proper outputs, henceforth."


+@deprecate_after_peft_backend
 class UNet2DConditionLoRAModelTests(unittest.TestCase):
     model_class = UNet2DConditionModel
     main_input_name = "sample"
@@ -1635,6 +1637,7 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
         assert max_diff_off_sample < expected_max_diff


+@deprecate_after_peft_backend
 class UNet3DConditionModelTests(unittest.TestCase):
     model_class = UNet3DConditionModel
     main_input_name = "sample"
@@ -1877,6 +1880,7 @@ class UNet3DConditionModelTests(unittest.TestCase):
 @slow
+@deprecate_after_peft_backend
 @require_torch_gpu
 class LoraIntegrationTests(unittest.TestCase):
     def test_dreambooth_old_format(self):
...