Commit 2bfa55f4 authored by Younes Belkada, committed by GitHub

[`core` / `PEFT` / `LoRA`] Integrate PEFT into Unet (#5151)



* v1

* add tests and fix previous failing tests

* fix CI

* add tests + v1 `PeftLayerScaler`

* style

* add scale retrieving mechanism system

* fix CI

* up

* up

* simple approach --> not same results for some reason

* fix issues

* fix copies

* remove unneeded method

* active adapters!

* fix merge conflicts

* up

* up

* kohya - test-1

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix scale

* fix copies

* add comment

* multi adapters

* fix tests

* oops

* v1 faster loading - in progress

* Revert "v1 faster loading - in progress"

This reverts commit ac925f81321e95fc8168184c3346bf3d75404d5a.

* kohya same generation

* fix some slow tests

* peft integration features for unet lora

1. Support for Multiple ranks/alphas
2. Support for Multiple active adapters
3. Support for enabling/disabling LoRAs

* fix `get_peft_kwargs`

* Update loaders.py

* add some tests

* add unfuse tests

* fix tests

* up

* add set adapter from sourab and tests

* fix multi adapter tests

* style & quality

* style

* remove comment

* fix `adapter_name` issues

* fix unet adapter name for sdxl

* fix enabling/disabling adapters

* fix fuse / unfuse unet

* nit

* fix

* up

* fix cpu offloading

* fix another slow test

* fix another offload test

* add more tests

* all slow tests pass

* style

* fix alpha pattern for unet and text encoder

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/attention.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* up

* up

* clarify comment

* comments

* change comment order

* change comment order

* stylr & quality

* Update tests/lora/test_lora_layers_peft.py
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* fix bugs and add tests

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/models/modeling_utils.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* refactor

* suggestion

* add break statemebt

* add compile tests

* move slow tests to peft tests as I modified them

* quality

* refactor a bit

* style

* change import

* style

* fix CI

* refactor slow tests one last time

* style

* oops

* oops

* oops

* final tweak tests

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update src/diffusers/loaders.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* comments

* Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* remove comments

* more comments

* try

* revert

* add `safe_merge` tests

* add comment

* style, comments and run tests in fp16

* add warnings

* fix doc test

* replace with `adapter_weights`

* add `get_active_adapters()`

* expose `get_list_adapters` method

* better error message

* Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* style

* trigger slow lora tests

* fix tests

* maybe fix last test

* revert

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* move `MIN_PEFT_VERSION`

* Apply suggestions from code review
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* let's not use class variable

* fix few nits

* change a bit offloading logic

* check earlier

* rm unneeded block

* break long line

* return empty list

* change logic a bit and address comments

* add typehint

* remove parenthesis

* fix

* revert to fp16 in tests

* add to gpu

* revert to old test

* style

* Update src/diffusers/loaders.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* change indent

* Apply suggestions from code review

* Apply suggestions from code review

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
parent 9bc55e8b
......@@ -31,7 +31,14 @@ from ...models.attention_processor import (
)
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
PIL_INTERPOLATION,
USE_PEFT_BACKEND,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
......@@ -283,7 +290,7 @@ class StableDiffusionXLAdapterPipeline(
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
else:
......@@ -425,7 +432,7 @@ class StableDiffusionXLAdapterPipeline(
bs_embed * num_images_per_prompt, -1
)
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
unscale_lora_layers(self.text_encoder_2)
......
......@@ -23,7 +23,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import TextToVideoSDPipelineOutput
......@@ -224,7 +231,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -352,7 +359,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -24,7 +24,14 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet3DConditionModel
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, replace_example_docstring, scale_lora_layers, unscale_lora_layers
from ...utils import (
USE_PEFT_BACKEND,
deprecate,
logging,
replace_example_docstring,
scale_lora_layers,
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import TextToVideoSDPipelineOutput
......@@ -286,7 +293,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -414,7 +421,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -18,7 +18,7 @@ from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL
from ...models.lora import adjust_lora_scale_text_encoder
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.outputs import BaseOutput
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
......@@ -426,7 +426,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not self.use_peft_backend:
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
else:
scale_lora_layers(self.text_encoder, lora_scale)
......@@ -554,7 +554,7 @@ class UniDiffuserPipeline(DiffusionPipeline):
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
if isinstance(self, LoraLoaderMixin) and self.use_peft_backend:
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder)
......
......@@ -31,7 +31,7 @@ from ...models.embeddings import (
)
from ...models.transformer_2d import Transformer2DModel
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...utils import is_torch_version, logging
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import apply_freeu
......@@ -1211,6 +1211,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
# 3. down
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
......@@ -1310,6 +1313,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self)
if not return_dict:
return (sample,)
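The two hunks above bracket the UNet forward pass: the LoRA layers are scaled on entry and unscaled just before returning. The sketch below illustrates that bracketing with a toy layer. `ToyLoraLayer` and `forward_with_lora_scale` are hypothetical names used only for illustration (they are not PEFT or diffusers classes), and the `try`/`finally` is an addition of this sketch, not something the commit does.

```python
class ToyLoraLayer:
    """Hypothetical stand-in for a LoRA layer; it only tracks a scaling factor."""

    def __init__(self, scaling=1.0):
        self.scaling = scaling
        self._base_scaling = scaling

    def scale_layer(self, scale):
        # Multiply the current scaling by the requested factor.
        self.scaling *= scale

    def unscale_layer(self):
        # Restore the scaling that was in place before `scale_layer` was called.
        self.scaling = self._base_scaling


def forward_with_lora_scale(layers, x, lora_scale=1.0):
    """Scale every layer, run a toy forward pass, then undo the scaling."""
    for layer in layers:
        layer.scale_layer(lora_scale)
    try:
        out = x
        for layer in layers:
            # Toy "forward": each layer multiplies the input by its scaling.
            out = out * layer.scaling
        return out
    finally:
        # Always undo the scaling so the next call starts from a clean state.
        for layer in layers:
            layer.unscale_layer()
```

Restoring the scale on exit matters because the scaling lives on the layer objects, so a forgotten unscale would leak into the next call.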
......
......@@ -26,9 +26,11 @@ from .constants import (
FLAX_WEIGHTS_NAME,
HF_MODULES_CACHE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
MIN_PEFT_VERSION,
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
USE_PEFT_BACKEND,
WEIGHTS_NAME,
)
from .deprecation_utils import deprecate
......@@ -86,6 +88,7 @@ from .loading_utils import load_image
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
check_peft_version,
get_adapter_name,
get_peft_kwargs,
recurse_remove_peft_layers,
......@@ -95,7 +98,11 @@ from .peft_utils import (
unscale_lora_layers,
)
from .pil_utils import PIL_INTERPOLATION, make_image_grid, numpy_to_pil, pt_to_pil
from .state_dict_utils import convert_state_dict_to_diffusers, convert_state_dict_to_peft
from .state_dict_utils import (
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
)
logger = get_logger(__name__)
......
......@@ -11,13 +11,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE, hf_cache_home
from packaging import version
from .import_utils import is_peft_available, is_transformers_available
default_cache_path = HUGGINGFACE_HUB_CACHE
MIN_PEFT_VERSION = "0.5.0"
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "diffusion_pytorch_model.bin"
......@@ -30,3 +36,16 @@ DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
# Below should be `True` if the current versions of `peft` and `transformers` are compatible with the
# PEFT backend. The PEFT backend is used automatically if the correct versions of the libraries are
# available.
# For PEFT the version has to be greater than `MIN_PEFT_VERSION` (0.5.0) and for transformers it has to be greater than 4.33.
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
) > version.parse(MIN_PEFT_VERSION)
_required_transformers_version = is_transformers_available() and version.parse(
version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse("4.33")
USE_PEFT_BACKEND = _required_peft_version and _required_transformers_version
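The gate above compares the base version of each installed library against a minimum with a strict `>`. The sketch below reproduces that comparison with the standard library only. `base_release` and `is_newer_than` are hypothetical helpers, and the regex is a simplified stand-in for `packaging.version.parse(...).base_version` that ignores epochs and other PEP 440 details.

```python
import re


def base_release(version_string):
    """Return the leading numeric release of a version string as a tuple of ints.

    Simplified stand-in for the base version: "0.6.0.dev0" -> (0, 6, 0).
    """
    match = re.match(r"\d+(?:\.\d+)*", version_string)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version_string!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def is_newer_than(installed, minimum):
    """Strictly-greater comparison, mirroring the `>` used in the hunk above."""
    a, b = base_release(installed), base_release(minimum)
    # Pad the shorter tuple with zeros so (4, 33) compares like (4, 33, 0).
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a > b


# A development build of the next release passes the gate...
print(is_newer_than("0.6.0.dev0", "0.5.0"))
# ...but the minimum version itself does not, because the comparison is strict.
print(is_newer_than("0.5.0", "0.5.0"))
```

Because the comparison is strict, a library installed at exactly the stated minimum does not enable the PEFT backend.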
......@@ -15,8 +15,11 @@
PEFT utilities: Utilities related to peft library
"""
import collections
import importlib
from .import_utils import is_torch_available
from packaging import version
from .import_utils import is_peft_available, is_torch_available
def recurse_remove_peft_layers(model):
......@@ -53,7 +56,6 @@ def recurse_remove_peft_layers(model):
module.padding,
module.dilation,
module.groups,
module.bias,
).to(module.weight.device)
new_module.weight = module.weight
......@@ -106,10 +108,11 @@ def unscale_lora_layers(model):
module.unscale_layer()
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
rank_pattern = {}
alpha_pattern = {}
r = lora_alpha = list(rank_dict.values())[0]
if len(set(rank_dict.values())) > 1:
# get the rank occurring the most number of times
r = collections.Counter(rank_dict.values()).most_common()[0][0]
......@@ -118,13 +121,22 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict):
rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}
if network_alpha_dict is not None and len(set(network_alpha_dict.values())) > 1:
# get the alpha occurring the most number of times
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
# for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
if network_alpha_dict is not None:
if len(set(network_alpha_dict.values())) > 1:
# get the alpha occurring the most number of times
lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]
# for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
if is_unet:
alpha_pattern = {
".".join(k.split(".lora_A.")[0].split(".")).replace(".alpha", ""): v
for k, v in alpha_pattern.items()
}
else:
alpha_pattern = {".".join(k.split(".down.")[0].split(".")[:-1]): v for k, v in alpha_pattern.items()}
else:
lora_alpha = set(network_alpha_dict.values()).pop()
# layer names without the Diffusers specific
target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
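`get_peft_kwargs` picks the most common rank as the default and records every module that deviates from it in a pattern dictionary. The sketch below isolates that step. `split_default_and_overrides` is a hypothetical helper and the state-dict keys are made up for illustration.

```python
import collections


def split_default_and_overrides(rank_dict):
    """Return (most common rank, {module name: rank} for modules that differ)."""
    # The most frequent rank becomes the default `r`.
    r = collections.Counter(rank_dict.values()).most_common()[0][0]
    # Every other module goes into the pattern, keyed by the module name,
    # i.e. everything before ".lora_B." in the state-dict key.
    rank_pattern = {
        key.split(".lora_B.")[0]: rank for key, rank in rank_dict.items() if rank != r
    }
    return r, rank_pattern


# Hypothetical keys, for illustration only.
rank_dict = {
    "attn1.to_q.lora_B.weight": 4,
    "attn1.to_k.lora_B.weight": 4,
    "attn1.to_v.lora_B.weight": 8,
}
r, rank_pattern = split_default_and_overrides(rank_dict)
print(r, rank_pattern)
```

The alpha handling in the hunk follows the same idea, with different key splitting for the UNet and the text encoder.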
......@@ -155,9 +167,9 @@ def set_adapter_layers(model, enabled=True):
if isinstance(module, BaseTunerLayer):
# The recent version of PEFT needs to call `enable_adapters` instead
if hasattr(module, "enable_adapters"):
module.enable_adapters(enabled=False)
module.enable_adapters(enabled=enabled)
else:
module.disable_adapters = True
module.disable_adapters = not enabled
def set_weights_and_activate_adapters(model, adapter_names, weights):
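The `set_adapter_layers` hunk above fixes a bug: before the change, adapters were disabled whatever value `enabled` had. The sketch below shows the corrected toggle on two toy layer types. `NewStyleLayer` and `LegacyLayer` are hypothetical stand-ins for the two layer interfaces the hunk distinguishes, not PEFT classes.

```python
class NewStyleLayer:
    """Hypothetical layer exposing an `enable_adapters` method."""

    def __init__(self):
        self.adapters_enabled = True

    def enable_adapters(self, enabled):
        self.adapters_enabled = enabled


class LegacyLayer:
    """Hypothetical layer that only has a `disable_adapters` attribute."""

    def __init__(self):
        self.disable_adapters = False


def set_adapter_layers(layers, enabled=True):
    for layer in layers:
        if hasattr(layer, "enable_adapters"):
            # Forward the flag instead of hard-coding `False`.
            layer.enable_adapters(enabled=enabled)
        else:
            # The legacy attribute has the opposite polarity.
            layer.disable_adapters = not enabled


layers = [NewStyleLayer(), LegacyLayer()]
set_adapter_layers(layers, enabled=False)
set_adapter_layers(layers, enabled=True)  # re-enabling now works for both kinds
```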
......@@ -182,3 +194,23 @@ def set_weights_and_activate_adapters(model, adapter_names, weights):
module.set_adapter(adapter_names)
else:
module.active_adapter = adapter_names
def check_peft_version(min_version: str) -> None:
r"""
Checks if the version of PEFT is compatible.
Args:
min_version (`str`):
The minimum version of PEFT to check against.
"""
if not is_peft_available():
raise ValueError("PEFT is not installed. Please install it with `pip install peft`")
is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) > version.parse(min_version)
if not is_peft_version_compatible:
raise ValueError(
f"The version of PEFT you are using is not compatible, please use a version that is greater"
f" than {min_version}"
)
......@@ -28,6 +28,22 @@ class StateDictType(enum.Enum):
DIFFUSERS = "diffusers"
# We need to define a proper mapping for Unet since it uses different output keys than text encoder
# e.g. to_q_lora -> q_proj / to_q
UNET_TO_DIFFUSERS = {
".to_out_lora.up": ".to_out.0.lora_B",
".to_out_lora.down": ".to_out.0.lora_A",
".to_q_lora.down": ".to_q.lora_A",
".to_q_lora.up": ".to_q.lora_B",
".to_k_lora.down": ".to_k.lora_A",
".to_k_lora.up": ".to_k.lora_B",
".to_v_lora.down": ".to_v.lora_A",
".to_v_lora.up": ".to_v.lora_B",
".lora.up": ".lora_B",
".lora.down": ".lora_A",
}
DIFFUSERS_TO_PEFT = {
".q_proj.lora_linear_layer.up": ".q_proj.lora_B",
".q_proj.lora_linear_layer.down": ".q_proj.lora_A",
......@@ -50,6 +66,8 @@ DIFFUSERS_OLD_TO_PEFT = {
".to_v_lora.down": ".v_proj.lora_A",
".to_out_lora.up": ".out_proj.lora_B",
".to_out_lora.down": ".out_proj.lora_A",
".lora_linear_layer.up": ".lora_B",
".lora_linear_layer.down": ".lora_A",
}
PEFT_TO_DIFFUSERS = {
......@@ -84,6 +102,10 @@ DIFFUSERS_STATE_DICT_MAPPINGS = {
StateDictType.PEFT: PEFT_TO_DIFFUSERS,
}
KEYS_TO_ALWAYS_REPLACE = {
".processor.": ".",
}
def convert_state_dict(state_dict, mapping):
r"""
......@@ -103,6 +125,12 @@ def convert_state_dict(state_dict, mapping):
"""
converted_state_dict = {}
for k, v in state_dict.items():
# First, filter out the keys that we always want to replace
for pattern in KEYS_TO_ALWAYS_REPLACE.keys():
if pattern in k:
new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern]
k = k.replace(pattern, new_pattern)
for pattern in mapping.keys():
if pattern in k:
new_pattern = mapping[pattern]
......@@ -184,3 +212,11 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
return convert_state_dict(state_dict, mapping)
def convert_unet_state_dict_to_peft(state_dict):
r"""
Converts a state dict from the UNet LoRA format to the PEFT key format (`lora_A` / `lora_B`) - i.e. by renaming some keys
"""
mapping = UNET_TO_DIFFUSERS
return convert_state_dict(state_dict, mapping)
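The conversion above amounts to substring renaming of state-dict keys: first the always-applied replacements, then the first matching entry of the chosen mapping. The sketch below shows this on a subset of `UNET_TO_DIFFUSERS`. `rename_keys` is a hypothetical name, the example keys are made up, and the `break` after the first match is an assumption, since the loop body is cut off in the hunk shown here.

```python
KEYS_TO_ALWAYS_REPLACE = {".processor.": "."}

# Subset of the UNET_TO_DIFFUSERS mapping from the hunk above.
UNET_TO_DIFFUSERS = {
    ".to_out_lora.up": ".to_out.0.lora_B",
    ".to_out_lora.down": ".to_out.0.lora_A",
    ".to_q_lora.down": ".to_q.lora_A",
    ".to_q_lora.up": ".to_q.lora_B",
}


def rename_keys(state_dict, mapping):
    converted = {}
    for key, value in state_dict.items():
        # First, apply the replacements that are always wanted.
        for pattern, replacement in KEYS_TO_ALWAYS_REPLACE.items():
            if pattern in key:
                key = key.replace(pattern, replacement)
        # Then apply the first matching entry of the format-specific mapping.
        for pattern, replacement in mapping.items():
            if pattern in key:
                key = key.replace(pattern, replacement)
                break  # assumption: at most one mapping entry applies per key
        converted[key] = value
    return converted


# Hypothetical keys with placeholder values instead of tensors.
old = {
    "attn1.processor.to_q_lora.down.weight": 1,
    "attn1.processor.to_out_lora.up.weight": 2,
}
new = rename_keys(old, UNET_TO_DIFFUSERS)
print(new)
```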
......@@ -673,6 +673,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
@deprecate_after_peft_backend
class SDXInpaintLoraMixinTests(unittest.TestCase):
def get_dummy_inputs(self, device, seed=0, img_res=64, output_pil=True):
# TODO: use tensor inputs instead of PIL, this is here just to leave the old expected_slices untouched
......@@ -1387,6 +1388,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
), "The pipeline was serialized with LoRA parameters fused inside of the respective modules. The loaded pipeline should yield proper outputs, henceforth."
@deprecate_after_peft_backend
class UNet2DConditionLoRAModelTests(unittest.TestCase):
model_class = UNet2DConditionModel
main_input_name = "sample"
......@@ -1635,6 +1637,7 @@ class UNet2DConditionLoRAModelTests(unittest.TestCase):
assert max_diff_off_sample < expected_max_diff
@deprecate_after_peft_backend
class UNet3DConditionModelTests(unittest.TestCase):
model_class = UNet3DConditionModel
main_input_name = "sample"
......@@ -1877,6 +1880,7 @@ class UNet3DConditionModelTests(unittest.TestCase):
@slow
@deprecate_after_peft_backend
@require_torch_gpu
class LoraIntegrationTests(unittest.TestCase):
def test_dreambooth_old_format(self):
......