Unverified Commit 13e8fdec authored by Sayak Paul, committed by GitHub

[feat] add `load_lora_adapter()` for compatible models (#9712)



* add first draft.

* fix

* updates.

* updates.

* updates

* updates

* updates.

* fix-copies

* lora constants.

* add tests

* Apply suggestions from code review
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* docstrings.

---------
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
parent c10f875f
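For orientation, a minimal usage sketch of the new entry point added by this PR. `load_lora_adapter()` is defined on `PeftAdapterMixin`, so any model class that inherits it (for example `FluxTransformer2DModel`) picks it up; the repository ids below are placeholders for illustration, not from this commit:

import torch
from diffusers import FluxTransformer2DModel

# Placeholder checkpoint and LoRA repo ids, used only to show the call shape.
transformer = FluxTransformer2DModel.from_pretrained(
    "some-org/some-flux-checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.load_lora_adapter(
    "some-org/some-lora-repo",
    weight_name="pytorch_lora_weights.safetensors",  # matches LORA_WEIGHT_NAME_SAFE defined in this diff
    adapter_name="my_lora",
    prefix="transformer",  # state-dict keys are filtered by this prefix before injection
)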
...@@ -51,6 +51,9 @@ if is_accelerate_available():
logger = logging.get_logger(__name__)
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
"""
...@@ -181,61 +184,7 @@ def _remove_text_encoder_monkey_patch(text_encoder):
text_encoder._hf_peft_config_loaded = None
class LoraBaseMixin:
"""Utility class for handling LoRAs."""
_lora_loadable_modules = []
num_fused_loras = 0
def load_lora_weights(self, **kwargs):
raise NotImplementedError("`load_lora_weights()` is not implemented.")
@classmethod
def save_lora_weights(cls, **kwargs):
raise NotImplementedError("`save_lora_weights()` not implemented.")
@classmethod
def lora_state_dict(cls, **kwargs):
raise NotImplementedError("`lora_state_dict()` is not implemented.")
@classmethod
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
@classmethod
def _fetch_state_dict(
cls,
def _fetch_state_dict(
pretrained_model_name_or_path_or_dict,
weight_name,
use_safetensors,
...@@ -248,9 +197,7 @@ class LoraBaseMixin:
subfolder,
user_agent,
allow_pickle,
):
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
...@@ -262,7 +209,7 @@ class LoraBaseMixin:
# friendliness where sometimes, it's not at all possible to automatically
# determine `weight_name`.
if weight_name is None:
weight_name = cls._best_guess_weight_name(
weight_name = _best_guess_weight_name(
pretrained_model_name_or_path_or_dict,
file_extension=".safetensors",
local_files_only=local_files_only,
...@@ -289,7 +236,7 @@ class LoraBaseMixin:
if model_file is None:
if weight_name is None:
weight_name = cls._best_guess_weight_name(
weight_name = _best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
)
model_file = _get_model_file(
...@@ -310,12 +257,10 @@ class LoraBaseMixin:
return state_dict
@classmethod
def _best_guess_weight_name(
cls, pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
):
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
def _best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
):
if local_files_only or HF_HUB_OFFLINE:
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
...@@ -324,9 +269,7 @@ class LoraBaseMixin:
if os.path.isfile(pretrained_model_name_or_path_or_dict):
return
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
targeted_files = [
f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)
]
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
else:
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
...@@ -353,6 +296,71 @@ class LoraBaseMixin:
weight_name = targeted_files[0]
return weight_name
class LoraBaseMixin:
"""Utility class for handling LoRAs."""
_lora_loadable_modules = []
num_fused_loras = 0
def load_lora_weights(self, **kwargs):
raise NotImplementedError("`load_lora_weights()` is not implemented.")
@classmethod
def save_lora_weights(cls, **kwargs):
raise NotImplementedError("`save_lora_weights()` not implemented.")
@classmethod
def lora_state_dict(cls, **kwargs):
raise NotImplementedError("`lora_state_dict()` is not implemented.")
@classmethod
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
@classmethod
def _fetch_state_dict(cls, *args, **kwargs):
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
return _fetch_state_dict(*args, **kwargs)
@classmethod
def _best_guess_weight_name(cls, *args, **kwargs):
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
return _best_guess_weight_name(*args, **kwargs)
def unload_lora_weights(self):
"""
Unloads the LoRA parameters.
...@@ -725,8 +733,6 @@ class LoraBaseMixin:
save_function: Callable,
safe_serialization: bool,
):
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
...
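As a quick recap of the hunks above: `_fetch_state_dict()` and `_best_guess_weight_name()` are now module-level helpers in `diffusers.loaders.lora_base`, with the old classmethods kept as deprecation shims. A minimal sketch of calling the module-level helper directly (mirroring the new tests further below; passing an in-memory state dict skips weight-name guessing and downloading):

import torch
from diffusers.loaders.lora_base import _fetch_state_dict

state_dict = torch.nn.Linear(3, 3).state_dict()
# A dict input is returned as-is; the many None arguments only matter for the
# download path taken when a repo id or local path is passed instead.
fetched = _fetch_state_dict(
    state_dict,
    weight_name=None,
    use_safetensors=False,
    local_files_only=True,
    cache_dir=None,
    force_download=False,
    proxies=None,
    token=None,
    revision=None,
    subfolder=None,
    user_agent=None,
    allow_pickle=None,
)
assert fetched is state_dict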
This diff is collapsed.
...@@ -16,18 +16,32 @@ import inspect
from functools import partial
from typing import Dict, List, Optional, Union
import torch.nn as nn
from ..utils import (
MIN_PEFT_VERSION,
USE_PEFT_BACKEND,
check_peft_version,
convert_unet_state_dict_to_peft,
delete_adapter_layers,
get_adapter_name,
get_peft_kwargs,
is_accelerate_available,
is_peft_available,
is_peft_version,
logging,
set_adapter_layers,
set_weights_and_activate_adapters,
)
from .lora_base import _fetch_state_dict
from .unet_loader_utils import _maybe_expand_lora_scales
if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
logger = logging.get_logger(__name__)
_SET_ADAPTER_SCALE_FN_MAPPING = {
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
...@@ -53,6 +67,215 @@ class PeftAdapterMixin:
_hf_peft_config_loaded = False
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
r"""
Loads a LoRA adapter into the underlying model.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
prefix (`str`, *optional*): Prefix to filter the state dict.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
adapter_name = kwargs.pop("adapter_name", None)
network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
allow_pickle = False
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(prefix)]
if len(transformer_keys) > 0:
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys}
if len(state_dict.keys()) > 0:
# Check with the first key whether the state dict is already in PEFT format.
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(self, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to `_pipeline`.
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def set_adapters(
self,
adapter_names: Union[List[str], str],
...
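To make the key handling inside `load_lora_adapter()` above concrete, here is a small self-contained sketch of the prefix filtering and rank inference on a toy PEFT-format state dict (the layer names are invented for illustration):

import torch

prefix = "transformer"
state_dict = {
    "transformer.blocks.0.attn.to_q.lora_A.weight": torch.zeros(4, 64),  # (rank, in_features)
    "transformer.blocks.0.attn.to_q.lora_B.weight": torch.zeros(64, 4),  # (out_features, rank)
    "text_encoder.layer.lora_A.weight": torch.zeros(4, 32),  # dropped by the prefix filter
}

# Keep only keys under `prefix` and strip the prefix, as done in the method above.
keys = [k for k in state_dict if k.startswith(prefix)]
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in keys}

# The LoRA rank is read off the second dimension of each `lora_B` weight.
rank = {k: v.shape[1] for k, v in state_dict.items() if "lora_B" in k}
print(rank)  # {'blocks.0.attn.to_q.lora_B.weight': 4}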
import os
import tempfile
import unittest
import torch
from diffusers.loaders.lora_base import LoraBaseMixin
class UtilityMethodDeprecationTests(unittest.TestCase):
def test_fetch_state_dict_cls_method_raises_warning(self):
state_dict = torch.nn.Linear(3, 3).state_dict()
with self.assertWarns(FutureWarning) as warning:
_ = LoraBaseMixin._fetch_state_dict(
state_dict,
weight_name=None,
use_safetensors=False,
local_files_only=True,
cache_dir=None,
force_download=False,
proxies=None,
token=None,
revision=None,
subfolder=None,
user_agent=None,
allow_pickle=None,
)
warning_message = str(warning.warnings[0].message)
assert "Using the `_fetch_state_dict()` method from" in warning_message
def test_best_guess_weight_name_cls_method_raises_warning(self):
with tempfile.TemporaryDirectory() as tmpdir:
state_dict = torch.nn.Linear(3, 3).state_dict()
torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin"))
with self.assertWarns(FutureWarning) as warning:
_ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir)
warning_message = str(warning.warnings[0].message)
assert "Using the `_best_guess_weight_name()` method from" in warning_message
...@@ -1787,7 +1787,7 @@ class PeftLoraLoaderMixinTests:
logger = (
logging.get_logger("diffusers.loaders.unet")
if self.unet_kwargs is not None
else logging.get_logger("diffusers.loaders.lora_pipeline")
else logging.get_logger("diffusers.loaders.peft")
)
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
...@@ -1826,7 +1826,7 @@ class PeftLoraLoaderMixinTests:
logger = (
logging.get_logger("diffusers.loaders.unet")
if self.unet_kwargs is not None
else logging.get_logger("diffusers.loaders.lora_pipeline")
else logging.get_logger("diffusers.loaders.peft")
)
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
...
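The two one-line test changes above follow from the code move: the "unexpected keys" warning is now emitted from `diffusers.loaders.peft` rather than `diffusers.loaders.lora_pipeline`, so `CaptureLogger` has to listen to that logger. A tiny sketch of the capture pattern, assuming `CaptureLogger` from `diffusers.utils.testing_utils` as used elsewhere in the test suite:

from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger

logger = logging.get_logger("diffusers.loaders.peft")
logger.setLevel(30)  # 30 == logging.WARNING
with CaptureLogger(logger) as cap_logger:
    logger.warning("example warning emitted during LoRA loading")
print(cap_logger.out)  # captured text can then be asserted against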