Unverified Commit 13e8fdec authored by Sayak Paul, committed by GitHub

[feat] add `load_lora_adapter()` for compatible models (#9712)



* add first draft.

* fix

* updates.

* updates.

* updates

* updates

* updates.

* fix-copies

* lora constants.

* add tests

* Apply suggestions from code review
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* docstrings.

---------
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
parent c10f875f
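A minimal usage sketch of the model-level API this commit adds (the LoRA repo id below is a placeholder; any model that inherits from `PeftAdapterMixin`, such as the Flux transformer, exposes the new `load_lora_adapter()` method):

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)

# Load a LoRA directly at the model level instead of through the pipeline-level
# `load_lora_weights()`. `prefix` filters the state-dict keys, `adapter_name` labels the adapter.
pipe.transformer.load_lora_adapter(
    "some-user/some-flux-lora",  # placeholder repo id
    weight_name="pytorch_lora_weights.safetensors",
    prefix="transformer",
    adapter_name="my_adapter",
)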
@@ -51,6 +51,9 @@ if is_accelerate_available():
logger = logging.get_logger(__name__)

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"


def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
    """
@@ -181,6 +184,119 @@ def _remove_text_encoder_monkey_patch(text_encoder):
    text_encoder._hf_peft_config_loaded = None
def _fetch_state_dict(
pretrained_model_name_or_path_or_dict,
weight_name,
use_safetensors,
local_files_only,
cache_dir,
force_download,
proxies,
token,
revision,
subfolder,
user_agent,
allow_pickle,
):
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
# Here we're relaxing the loading check to enable more Inference API
# friendliness where sometimes, it's not at all possible to automatically
# determine `weight_name`.
if weight_name is None:
weight_name = _best_guess_weight_name(
pretrained_model_name_or_path_or_dict,
file_extension=".safetensors",
local_files_only=local_files_only,
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
pass
if model_file is None:
if weight_name is None:
weight_name = _best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
)
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name or LORA_WEIGHT_NAME,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
return state_dict
def _best_guess_weight_name(
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
):
if local_files_only or HF_HUB_OFFLINE:
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
targeted_files = []
if os.path.isfile(pretrained_model_name_or_path_or_dict):
return
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
else:
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
if len(targeted_files) == 0:
return
# "scheduler" does not correspond to a LoRA checkpoint.
# "optimizer" does not correspond to a LoRA checkpoint
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
targeted_files = list(
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
)
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
if len(targeted_files) > 1:
raise ValueError(
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
)
weight_name = targeted_files[0]
return weight_name
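An illustrative sketch of how the two new module-level helpers compose (the repo id is a placeholder; the keyword values mirror the defaults used by the loader mixins in this PR):

from diffusers.loaders.lora_base import _best_guess_weight_name, _fetch_state_dict

repo_id = "some-user/some-lora-repo"  # placeholder

# Guess the weights filename when the caller did not pass `weight_name` explicitly.
weight_name = _best_guess_weight_name(repo_id, file_extension=".safetensors")

# Download (or read from the local cache) and return the LoRA state dict.
state_dict = _fetch_state_dict(
    pretrained_model_name_or_path_or_dict=repo_id,
    weight_name=weight_name,
    use_safetensors=True,
    local_files_only=False,
    cache_dir=None,
    force_download=False,
    proxies=None,
    token=None,
    revision=None,
    subfolder=None,
    user_agent={"file_type": "attn_procs_weights", "framework": "pytorch"},
    allow_pickle=False,
)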
class LoraBaseMixin:
    """Utility class for handling LoRAs."""
@@ -234,124 +350,16 @@ class LoraBaseMixin:
        return (is_model_cpu_offload, is_sequential_cpu_offload)
    @classmethod
    def _fetch_state_dict(cls, *args, **kwargs):
        deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
        deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
        return _fetch_state_dict(*args, **kwargs)

    @classmethod
    def _best_guess_weight_name(cls, *args, **kwargs):
        deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
        deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
        return _best_guess_weight_name(*args, **kwargs)
    def unload_lora_weights(self):
        """
@@ -725,8 +733,6 @@ class LoraBaseMixin:
        save_function: Callable,
        safe_serialization: bool,
    ):
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return
......
@@ -21,7 +21,6 @@ from ..utils import (
    USE_PEFT_BACKEND,
    convert_state_dict_to_diffusers,
    convert_state_dict_to_peft,
    deprecate,
    get_adapter_name,
    get_peft_kwargs,
@@ -33,7 +32,7 @@ from ..utils import (
    logging,
    scale_lora_layers,
)
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin, _fetch_state_dict  # noqa
from .lora_conversion_utils import (
    _convert_kohya_flux_lora_to_diffusers,
    _convert_non_diffusers_lora_to_diffusers,
@@ -62,9 +61,6 @@ TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
    r"""
@@ -222,7 +218,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
@@ -282,7 +278,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
@@ -341,7 +339,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
@@ -601,7 +601,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
        """
@@ -744,7 +746,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
@@ -805,7 +807,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
@@ -865,7 +869,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
@@ -1182,7 +1188,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
@@ -1226,7 +1232,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
        """
@@ -1250,13 +1258,17 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")
        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
        if len(transformer_state_dict) > 0:
            self.load_lora_into_transformer(
                state_dict,
                transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
                adapter_name=adapter_name,
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )

        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
@@ -1301,94 +1313,24 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )
    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -1424,7 +1366,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
@@ -1742,7 +1686,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
@@ -1819,7 +1763,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
@@ -1843,14 +1789,18 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        transformer_state_dict = {k: v for k, v in state_dict.items() if "transformer." in k}
        if len(transformer_state_dict) > 0:
            self.load_lora_into_transformer(
                state_dict,
                network_alphas=network_alphas,
                transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
                adapter_name=adapter_name,
                _pipeline=self,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )

        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
@@ -1881,104 +1831,32 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            transformer (`FluxTransformer2DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
""" """
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError( raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
) )
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict # Load the layers corresponding to transformer.
keys = list(state_dict.keys()) keys = list(state_dict.keys())
transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
transformer_keys = [k for k in keys if k.startswith(cls.transformer_name)] if transformer_present:
state_dict = { logger.info(f"Loading {cls.transformer_name}.")
k.replace(f"{cls.transformer_name}.", ""): v for k, v in state_dict.items() if k in transformer_keys transformer.load_lora_adapter(
} state_dict,
network_alphas=network_alphas,
if len(state_dict.keys()) > 0: adapter_name=adapter_name,
# check with first key if is not in peft format _pipeline=_pipeline,
first_key = next(iter(state_dict.keys())) low_cpu_mem_usage=low_cpu_mem_usage,
if "lora_A" not in first_key: )
state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(transformer, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
if network_alphas is not None and len(network_alphas) >= 1:
prefix = cls.transformer_name
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(transformer)
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = cls._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, transformer, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(transformer, state_dict, adapter_name, **peft_kwargs)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2014,7 +1892,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
@@ -2242,7 +2122,10 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
    text_encoder_name = TEXT_ENCODER_NAME

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.FluxLoraLoaderMixin.load_lora_into_transformer with FluxTransformer2DModel->UVit2DModel
    def load_lora_into_transformer(
        cls, state_dict, network_alphas, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -2255,93 +2138,32 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
                The value of the network alpha used for stable learning and preventing underflow. This value has the
                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
            transformer (`UVit2DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        keys = list(state_dict.keys())
        transformer_present = any(key.startswith(cls.transformer_name) for key in keys)
        if transformer_present:
            logger.info(f"Loading {cls.transformer_name}.")
            transformer.load_lora_adapter(
                state_dict,
                network_alphas=network_alphas,
                adapter_name=adapter_name,
                _pipeline=_pipeline,
                low_cpu_mem_usage=low_cpu_mem_usage,
            )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_text_encoder
@@ -2377,7 +2199,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
@@ -2619,7 +2443,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
            "framework": "pytorch",
        }

        state_dict = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
@@ -2658,7 +2482,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
        """
@@ -2691,7 +2517,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
        )

    @classmethod
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
    def load_lora_into_transformer(
        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False
    ):
@@ -2703,99 +2529,29 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                encoder lora layers.
            transformer (`CogVideoXTransformer3DModel`):
                The Transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
        )

    @classmethod
    # Adapted from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.save_lora_weights without support for text encoder
......
@@ -16,18 +16,32 @@ import inspect
from functools import partial
from typing import Dict, List, Optional, Union

import torch.nn as nn

from ..utils import (
    MIN_PEFT_VERSION,
    USE_PEFT_BACKEND,
    check_peft_version,
    convert_unet_state_dict_to_peft,
    delete_adapter_layers,
    get_adapter_name,
    get_peft_kwargs,
    is_accelerate_available,
    is_peft_available,
    is_peft_version,
    logging,
    set_adapter_layers,
    set_weights_and_activate_adapters,
)
from .lora_base import _fetch_state_dict
from .unet_loader_utils import _maybe_expand_lora_scales


if is_accelerate_available():
    from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

logger = logging.get_logger(__name__)

_SET_ADAPTER_SCALE_FN_MAPPING = {
    "UNet2DConditionModel": _maybe_expand_lora_scales,
    "UNetMotionModel": _maybe_expand_lora_scales,
@@ -53,6 +67,215 @@ class PeftAdapterMixin:
    _hf_peft_config_loaded = False
@classmethod
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
def _optionally_disable_offloading(cls, _pipeline):
"""
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
Args:
_pipeline (`DiffusionPipeline`):
The pipeline to disable offloading for.
Returns:
tuple:
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
"""
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload:
is_sequential_cpu_offload = (
isinstance(component._hf_hook, AlignDevicesHook)
or hasattr(component._hf_hook, "hooks")
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
return (is_model_cpu_offload, is_sequential_cpu_offload)
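For context, a hedged sketch of how the returned flags are consumed once loading finishes (it mirrors the pattern at the end of `load_lora_adapter` below; `_pipeline` is an optional DiffusionPipeline):

# Temporarily drop any accelerate offload hooks before touching the model weights.
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)

...  # inject the adapter and load its weights while no offload hooks are attached

# Offload back, restoring whichever mode was active before loading.
if is_model_cpu_offload:
    _pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
    _pipeline.enable_sequential_cpu_offload()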
def load_lora_adapter(self, pretrained_model_name_or_path_or_dict, prefix="transformer", **kwargs):
r"""
Loads a LoRA adapter into the underlying model.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
prefix (`str`, *optional*): Prefix to filter the state dict.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
network_alphas (`Dict[str, float]`):
The value of the network alpha used for stable learning and preventing underflow. This value has the
same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
low_cpu_mem_usage (`bool`, *optional*):
Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
weights.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
adapter_name = kwargs.pop("adapter_name", None)
network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
allow_pickle = False
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
keys = list(state_dict.keys())
transformer_keys = [k for k in keys if k.startswith(prefix)]
if len(transformer_keys) > 0:
state_dict = {k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in transformer_keys}
if len(state_dict.keys()) > 0:
# check with first key if is not in peft format
first_key = next(iter(state_dict.keys()))
if "lora_A" not in first_key:
state_dict = convert_unet_state_dict_to_peft(state_dict)
if adapter_name in getattr(self, "peft_config", {}):
raise ValueError(
f"Adapter name {adapter_name} already in use in the transformer - please select a new adapter name."
)
rank = {}
for key, val in state_dict.items():
if "lora_B" in key:
rank[key] = val.shape[1]
if network_alphas is not None and len(network_alphas) >= 1:
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"] and is_peft_version("<", "0.9.0"):
raise ValueError(
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
)
else:
lora_config_kwargs.pop("use_dora")
lora_config = LoraConfig(**lora_config_kwargs)
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to `_pipeline`.
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
# otherwise loading LoRA weights will lead to an error
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
peft_kwargs = {}
if is_peft_version(">=", "0.13.1"):
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
inject_adapter_in_model(lora_config, self, adapter_name=adapter_name, **peft_kwargs)
incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name, **peft_kwargs)
warn_msg = ""
if incompatible_keys is not None:
# Check only for unexpected keys.
unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
if unexpected_keys:
lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
if lora_unexpected_keys:
warn_msg = (
f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
f" {', '.join(lora_unexpected_keys)}. "
)
# Filter missing keys specific to the current adapter.
missing_keys = getattr(incompatible_keys, "missing_keys", None)
if missing_keys:
lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
if lora_missing_keys:
warn_msg += (
f"Loading adapter weights from state_dict led to missing keys in the model:"
f" {', '.join(lora_missing_keys)}."
)
if warn_msg:
logger.warning(warn_msg)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
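A short usage sketch of the method above with an already-loaded state dict; `pipe` and `lora_state_dict` are placeholders, and a plain dict is accepted just like a Hub repo id or a local path:

# Assumes `pipe` is a pipeline whose transformer inherits from `PeftAdapterMixin`
# and `lora_state_dict` was obtained elsewhere (for example via a `lora_state_dict()` call).
pipe.transformer.load_lora_adapter(
    lora_state_dict,
    prefix="transformer",      # keep only keys that start with "transformer."
    adapter_name="style_lora",
)
pipe.transformer.set_adapters("style_lora")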
    def set_adapters(
        self,
        adapter_names: Union[List[str], str],
......
import os
import tempfile
import unittest
import torch
from diffusers.loaders.lora_base import LoraBaseMixin
class UtilityMethodDeprecationTests(unittest.TestCase):
def test_fetch_state_dict_cls_method_raises_warning(self):
state_dict = torch.nn.Linear(3, 3).state_dict()
with self.assertWarns(FutureWarning) as warning:
_ = LoraBaseMixin._fetch_state_dict(
state_dict,
weight_name=None,
use_safetensors=False,
local_files_only=True,
cache_dir=None,
force_download=False,
proxies=None,
token=None,
revision=None,
subfolder=None,
user_agent=None,
allow_pickle=None,
)
warning_message = str(warning.warnings[0].message)
assert "Using the `_fetch_state_dict()` method from" in warning_message
def test_best_guess_weight_name_cls_method_raises_warning(self):
with tempfile.TemporaryDirectory() as tmpdir:
state_dict = torch.nn.Linear(3, 3).state_dict()
torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin"))
with self.assertWarns(FutureWarning) as warning:
_ = LoraBaseMixin._best_guess_weight_name(pretrained_model_name_or_path_or_dict=tmpdir)
warning_message = str(warning.warnings[0].message)
assert "Using the `_best_guess_weight_name()` method from" in warning_message
@@ -1787,7 +1787,7 @@ class PeftLoraLoaderMixinTests:
        logger = (
            logging.get_logger("diffusers.loaders.unet")
            if self.unet_kwargs is not None
            else logging.get_logger("diffusers.loaders.peft")
        )
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
@@ -1826,7 +1826,7 @@ class PeftLoraLoaderMixinTests:
        logger = (
            logging.get_logger("diffusers.loaders.unet")
            if self.unet_kwargs is not None
            else logging.get_logger("diffusers.loaders.peft")
        )
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
......