Unverified commit 368958df, authored by Sayak Paul, committed by GitHub

[LoRA] parse metadata from LoRA and save metadata (#11324)



* feat: parse metadata from lora state dicts.

* tests

* fix tests

* key renaming

* fix

* smol update

* smol updates

* load metadata.

* automatically save metadata in save_lora_adapter.

* propagate changes.

* changes

* add test to models too.

* tighter tests.

* updates

* fixes

* rename tests.

* sorted.

* Update src/diffusers/loaders/lora_base.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* review suggestions.

* removeprefix.

* propagate changes.

* fix-copies

* sd

* docs.

* fixes

* get review ready.

* one more test to catch error.

* change to a different approach.

* fix-copies.

* todo

* sd3

* update

* revert changes in get_peft_kwargs.

* update

* fixes

* fixes

* simplify _load_sft_state_dict_metadata

* update

* style fix

* update

* update

* update

* empty commit

* _pack_dict_with_prefix

* update

* TODO 1.

* todo: 2.

* todo: 3.

* update

* update

* Apply suggestions from code review
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* reraise.

* move argument.

---------
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent e52ceae3
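In essence, this change makes diffusers persist each adapter's `LoraConfig` as JSON inside the safetensors file header under a `lora_adapter_metadata` key, then parse it back at load time so the config no longer has to be inferred from tensor shapes and `network_alphas`. Before the diff, a minimal round-trip sketch of the underlying mechanism, using only the public `safetensors` API (the config dict and file name are illustrative, not the diffusers API):

import json

import torch
from safetensors import safe_open
from safetensors.torch import save_file

LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"

# Stand-in for LoraConfig.to_dict(); real configs carry more fields.
adapter_config = {"r": 4, "lora_alpha": 8, "target_modules": ["to_q", "to_k"]}
weights = {"to_q.lora_A.weight": torch.zeros(4, 16)}

# Save: safetensors header metadata must be str -> str, so the config is JSON-encoded.
metadata = {"format": "pt", LORA_ADAPTER_METADATA_KEY: json.dumps(adapter_config, sort_keys=True)}
save_file(weights, "lora.safetensors", metadata=metadata)

# Load: read the header back and rebuild the config without guessing from shapes.
with safe_open("lora.safetensors", framework="pt", device="cpu") as f:
    raw = (f.metadata() or {}).get(LORA_ADAPTER_METADATA_KEY)
parsed = json.loads(raw) if raw else None
assert parsed == adapter_config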
@@ -282,10 +282,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         model_file = _get_model_file(
             pretrained_model_name_or_path_or_dict,
             weights_name=weight_name,
...
@@ -159,10 +159,7 @@ class IPAdapterMixin:
                 " `low_cpu_mem_usage=False`."
             )
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         state_dicts = []
         for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
             pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -465,10 +462,7 @@ class FluxIPAdapterMixin:
                 " `low_cpu_mem_usage=False`."
             )
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         state_dicts = []
         for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
             pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -750,10 +744,7 @@ class SD3IPAdapterMixin:
                 " `low_cpu_mem_usage=False`."
             )
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
 
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             model_file = _get_model_file(
...
@@ -14,6 +14,7 @@
 import copy
 import inspect
+import json
 import os
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@ from ..utils import (
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.state_dict_utils import _load_sft_state_dict_metadata
 
 
 if is_transformers_available():
@@ -62,6 +64,7 @@ logger = logging.get_logger(__name__)
 
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
 
 
 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
@@ -206,6 +209,7 @@ def _fetch_state_dict(
     subfolder,
     user_agent,
     allow_pickle,
+    metadata=None,
 ):
     model_file = None
     if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -236,11 +240,14 @@ def _fetch_state_dict(
                     user_agent=user_agent,
                 )
                 state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                metadata = _load_sft_state_dict_metadata(model_file)
             except (IOError, safetensors.SafetensorError) as e:
                 if not allow_pickle:
                     raise e
                 # try loading non-safetensors weights
                 model_file = None
+                metadata = None
                 pass
 
         if model_file is None:
@@ -261,10 +268,11 @@ def _fetch_state_dict(
                 user_agent=user_agent,
             )
             state_dict = load_state_dict(model_file)
+            metadata = None
     else:
         state_dict = pretrained_model_name_or_path_or_dict
 
-    return state_dict
+    return state_dict, metadata
 
 
 def _best_guess_weight_name(
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
     return weight_name
 
 
+def _pack_dict_with_prefix(state_dict, prefix):
+    sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+    return sd_with_prefix
+
+
 def _load_lora_into_text_encoder(
     state_dict,
     network_alphas,
@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
     _pipeline=None,
     low_cpu_mem_usage=False,
     hotswap: bool = False,
+    metadata=None,
 ):
     if not USE_PEFT_BACKEND:
         raise ValueError("PEFT backend is required for this method.")
 
+    if network_alphas and metadata:
+        raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
+
     peft_kwargs = {}
     if low_cpu_mem_usage:
         if not is_peft_version(">=", "0.13.1"):
@@ -349,6 +366,8 @@ def _load_lora_into_text_encoder(
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
         state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        if metadata is not None:
+            metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
 
     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
@@ -376,7 +395,10 @@ def _load_lora_into_text_encoder(
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
             network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
 
-        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+        if metadata is not None:
+            lora_config_kwargs = metadata
+        else:
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
 
         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"]:
if is_peft_version("<=", "0.13.2"): if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias") lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs) try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e
# adapter_name # adapter_name
if adapter_name is None: if adapter_name is None:
...@@ -889,8 +914,7 @@ class LoraBaseMixin: ...@@ -889,8 +914,7 @@ class LoraBaseMixin:
@staticmethod @staticmethod
def pack_weights(layers, prefix): def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} return _pack_dict_with_prefix(layers_weights, prefix)
return layers_state_dict
@staticmethod @staticmethod
def write_lora_layers( def write_lora_layers(
@@ -900,16 +924,32 @@ class LoraBaseMixin:
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        lora_adapter_metadata: Optional[dict] = None,
     ):
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        if lora_adapter_metadata and not safe_serialization:
+            raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+        if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
+            raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
+
         if save_function is None:
             if safe_serialization:
 
                 def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                    # Inject framework format.
+                    metadata = {"format": "pt"}
+                    if lora_adapter_metadata:
+                        for key, value in lora_adapter_metadata.items():
+                            if isinstance(value, set):
+                                lora_adapter_metadata[key] = list(value)
+                        metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+                            lora_adapter_metadata, indent=2, sort_keys=True
+                        )
+
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)
 
             else:
                 save_function = torch.save
...
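A side note on the set-to-list conversion in the save path above: `LoraConfig.to_dict()` can contain set-valued fields (for example `target_modules`), and `json.dumps` cannot encode Python sets, hence the normalization before serializing. A minimal illustration of the failure and the fix:

import json

config = {"target_modules": {"to_q", "to_k"}}  # sets appear in LoraConfig.to_dict()
try:
    json.dumps(config)
except TypeError as e:
    print(e)  # Object of type set is not JSON serializable

config = {k: list(v) if isinstance(v, set) else v for k, v in config.items()}
print(json.dumps(config, sort_keys=True))  # now encodes cleanly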
(One file's diff is collapsed and not shown.)
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import json
 import os
 from functools import partial
 from pathlib import Path
@@ -185,6 +186,7 @@ class PeftAdapterMixin:
                 Note that hotswapping adapters of the text encoder is not yet supported. There are some further
                 limitations to this technique, which are documented here:
                 https://huggingface.co/docs/peft/main/en/package_reference/hotswap
+            metadata: TODO
         """
         from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
         from peft.tuners.tuners_utils import BaseTunerLayer
@@ -202,6 +204,7 @@ class PeftAdapterMixin:
         network_alphas = kwargs.pop("network_alphas", None)
         _pipeline = kwargs.pop("_pipeline", None)
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
+        metadata = kwargs.pop("metadata", None)
         allow_pickle = False
 
         if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
@@ -209,12 +212,9 @@ class PeftAdapterMixin:
                 "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
             )
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
 
-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -227,12 +227,17 @@ class PeftAdapterMixin:
             subfolder=subfolder,
             user_agent=user_agent,
             allow_pickle=allow_pickle,
+            metadata=metadata,
         )
 
         if network_alphas is not None and prefix is None:
             raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
+        if network_alphas and metadata:
+            raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")
 
         if prefix is not None:
             state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+            if metadata is not None:
+                metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
 
         if len(state_dict) > 0:
             if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
@@ -267,7 +272,12 @@ class PeftAdapterMixin:
                     k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
                 }
 
-            lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
+            if metadata is not None:
+                lora_config_kwargs = metadata
+            else:
+                lora_config_kwargs = get_peft_kwargs(
+                    rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
+                )
             _maybe_raise_error_for_ambiguity(lora_config_kwargs)
 
             if "use_dora" in lora_config_kwargs:
@@ -290,7 +300,11 @@ class PeftAdapterMixin:
                 if is_peft_version("<=", "0.13.2"):
                     lora_config_kwargs.pop("lora_bias")
 
-            lora_config = LoraConfig(**lora_config_kwargs)
+            try:
+                lora_config = LoraConfig(**lora_config_kwargs)
+            except TypeError as e:
+                raise TypeError("`LoraConfig` class could not be instantiated.") from e
+
             # adapter_name
             if adapter_name is None:
                 adapter_name = get_adapter_name(self)
@@ -445,17 +459,13 @@ class PeftAdapterMixin:
                 underlying model has multiple adapters loaded.
             upcast_before_saving (`bool`, defaults to `False`):
                 Whether to cast the underlying model to `torch.float32` before serialization.
-            save_function (`Callable`):
-                The function to use to save the state dictionary. Useful during distributed training when you need to
-                replace `torch.save` with another method. Can be configured with the environment variable
-                `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
         """
         from peft.utils import get_peft_model_state_dict
 
-        from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
+        from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
 
         if adapter_name is None:
             adapter_name = get_adapter_name(self)
@@ -463,6 +473,8 @@ class PeftAdapterMixin:
         if adapter_name not in getattr(self, "peft_config", {}):
             raise ValueError(f"Adapter name {adapter_name} not found in the model.")
 
+        lora_adapter_metadata = self.peft_config[adapter_name].to_dict()
+
         lora_layers_to_save = get_peft_model_state_dict(
             self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
         )
@@ -472,7 +484,15 @@ class PeftAdapterMixin:
         if safe_serialization:
 
             def save_function(weights, filename):
-                return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                # Inject framework format.
+                metadata = {"format": "pt"}
+                if lora_adapter_metadata is not None:
+                    for key, value in lora_adapter_metadata.items():
+                        if isinstance(value, set):
+                            lora_adapter_metadata[key] = list(value)
+                    metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+
+                return safetensors.torch.save_file(weights, filename, metadata=metadata)
 
         else:
             save_function = torch.save
@@ -485,7 +505,6 @@ class PeftAdapterMixin:
         else:
             weight_name = LORA_WEIGHT_NAME
 
-        # TODO: we could consider saving the `peft_config` as well.
         save_path = Path(save_directory, weight_name).as_posix()
         save_function(lora_layers_to_save, save_path)
         logger.info(f"Model weights saved in {save_path}")
...
@@ -155,10 +155,7 @@ class UNet2DConditionLoadersMixin:
         use_safetensors = True
         allow_pickle = True
 
-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
 
         model_file = None
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
...
@@ -16,6 +16,7 @@ State dict utilities: utility methods for converting state dicts easily
 """
 
 import enum
+import json
 
 from .import_utils import is_torch_available
 from .logging import get_logger
@@ -347,3 +348,16 @@ def state_dict_all_zero(state_dict, filter_str=None):
         state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}
 
     return all(torch.all(param == 0).item() for param in state_dict.values())
+
+
+def _load_sft_state_dict_metadata(model_file: str):
+    import safetensors.torch
+
+    from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
+    with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
+        metadata = f.metadata() or {}
+
+    metadata.pop("format", None)
+    raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+    return json.loads(raw) if raw else None
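For reference, the header that `_load_sft_state_dict_metadata` reads can also be inspected directly with the stock `safetensors` API; the file name below is a stand-in for any file produced by the new save path:

import json

from safetensors import safe_open

with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    header = f.metadata() or {}

print(header.get("format"))  # "pt", injected by the save function
raw = header.get("lora_adapter_metadata")
if raw:
    print(json.loads(raw))  # the stored LoraConfig kwargs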
@@ -133,6 +133,29 @@ def numpy_cosine_similarity_distance(a, b):
     return distance
 
 
+def check_if_dicts_are_equal(dict1, dict2):
+    dict1, dict2 = dict1.copy(), dict2.copy()
+
+    for key, value in dict1.items():
+        if isinstance(value, set):
+            dict1[key] = sorted(value)
+    for key, value in dict2.items():
+        if isinstance(value, set):
+            dict2[key] = sorted(value)
+
+    for key in dict1:
+        if key not in dict2:
+            return False
+        if dict1[key] != dict2[key]:
+            return False
+
+    for key in dict2:
+        if key not in dict1:
+            return False
+
+    return True
+
+
 def print_tensor_test(
     tensor,
     limit_to_slices=None,
...
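A quick usage sketch of the new helper, assuming it is imported from `diffusers.utils.testing_utils` (values are illustrative): set-valued entries are normalized with `sorted()` before comparison, so a set on one side matches its sorted-list form on the other:

from diffusers.utils.testing_utils import check_if_dicts_are_equal

a = {"r": 4, "target_modules": {"to_q", "to_k"}}
b = {"r": 4, "target_modules": ["to_k", "to_q"]}  # sorted-list form of the same set
assert check_if_dicts_are_equal(a, b)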
@@ -24,11 +24,7 @@ from diffusers import (
     WanPipeline,
     WanTransformer3DModel,
 )
-from diffusers.utils.testing_utils import (
-    floats_tensor,
-    require_peft_backend,
-    skip_mps,
-)
+from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
 
 
 sys.path.append(".")
...
@@ -22,6 +22,7 @@ from itertools import product
 import numpy as np
 import pytest
 import torch
+from parameterized import parameterized
 
 from diffusers import (
     AutoencoderKL,
@@ -33,6 +34,7 @@ from diffusers.utils import logging
 from diffusers.utils.import_utils import is_peft_available
 from diffusers.utils.testing_utils import (
     CaptureLogger,
+    check_if_dicts_are_equal,
     floats_tensor,
     is_torch_version,
     require_peft_backend,
@@ -71,6 +73,13 @@ def check_if_lora_correctly_set(model) -> bool:
     return False
 
 
+def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str):
+    extracted = {
+        k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.")
+    }
+    assert check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"])
+
+
 def initialize_dummy_state_dict(state_dict):
     if not all(v.device.type == "meta" for _, v in state_dict.items()):
         raise ValueError("`state_dict` has non-meta values.")
@@ -118,7 +127,7 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
 
-    def get_dummy_components(self, scheduler_cls=None, use_dora=False):
+    def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
         if self.has_two_text_encoders and self.has_three_text_encoders:
@@ -126,6 +135,7 @@ class PeftLoraLoaderMixinTests:
         scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
         rank = 4
+        lora_alpha = rank if lora_alpha is None else lora_alpha
 
         torch.manual_seed(0)
         if self.unet_kwargs is not None:
@@ -161,7 +171,7 @@ class PeftLoraLoaderMixinTests:
         text_lora_config = LoraConfig(
             r=rank,
-            lora_alpha=rank,
+            lora_alpha=lora_alpha,
             target_modules=self.text_encoder_target_modules,
             init_lora_weights=False,
             use_dora=use_dora,
denoiser_lora_config = LoraConfig( denoiser_lora_config = LoraConfig(
r=rank, r=rank,
lora_alpha=rank, lora_alpha=lora_alpha,
target_modules=self.denoiser_target_modules, target_modules=self.denoiser_target_modules,
init_lora_weights=False, init_lora_weights=False,
use_dora=use_dora, use_dora=use_dora,
@@ -246,6 +256,13 @@ class PeftLoraLoaderMixinTests:
                 state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
         return state_dicts
 
+    def _get_lora_adapter_metadata(self, modules_to_save):
+        metadatas = {}
+        for module_name, module in modules_to_save.items():
+            if module is not None:
+                metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
+        return metadatas
+
     def _get_modules_to_save(self, pipe, has_denoiser=False):
         modules_to_save = {}
         lora_loadable_modules = self.pipeline_class._lora_loadable_modules
@@ -2214,6 +2231,86 @@ class PeftLoraLoaderMixinTests:
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             pipe(**inputs, generator=torch.manual_seed(0))[0]
 
+    @parameterized.expand([4, 8, 16])
+    def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
+            scheduler_cls, lora_alpha=lora_alpha
+        )
+        pipe = self.pipeline_class(**components)
+
+        pipe, _ = self.check_if_adapters_added_correctly(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+            pipe.unload_lora_weights()
+
+            out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True)
+            if len(out) == 3:
+                _, _, parsed_metadata = out
+            elif len(out) == 2:
+                _, parsed_metadata = out
+
+            denoiser_key = (
+                f"{self.pipeline_class.transformer_name}"
+                if self.transformer_kwargs is not None
+                else f"{self.pipeline_class.unet_name}"
+            )
+            self.assertTrue(any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata))
+            check_module_lora_metadata(
+                parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key
+            )
+
+            if "text_encoder" in self.pipeline_class._lora_loadable_modules:
+                text_encoder_key = self.pipeline_class.text_encoder_name
+                self.assertTrue(any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata))
+                check_module_lora_metadata(
+                    parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key
+                )
+
+            if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
+                text_encoder_2_key = "text_encoder_2"
+                self.assertTrue(any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata))
+                check_module_lora_metadata(
+                    parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key
+                )
+
+    @parameterized.expand([4, 8, 16])
+    def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
+        scheduler_cls = self.scheduler_classes[0]
+        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
+            scheduler_cls, lora_alpha=lora_alpha
+        )
+        pipe = self.pipeline_class(**components).to(torch_device)
+        _, _, inputs = self.get_dummy_inputs(with_generator=False)
+
+        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        self.assertTrue(output_no_lora.shape == self.output_shape)
+
+        pipe, _ = self.check_if_adapters_added_correctly(
+            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
+        )
+        output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
+            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
+            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
+            pipe.unload_lora_weights()
+            pipe.load_lora_weights(tmpdir)
+
+            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
+
+            self.assertTrue(
+                np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
+            )
+
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:
...
@@ -30,6 +30,7 @@ from typing import Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import requests_mock
+import safetensors.torch
 import torch
 import torch.nn as nn
 from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
@@ -62,6 +63,7 @@ from diffusers.utils.testing_utils import (
     backend_max_memory_allocated,
     backend_reset_peak_memory_stats,
     backend_synchronize,
+    check_if_dicts_are_equal,
     get_python_version,
     is_torch_compile,
     numpy_cosine_similarity_distance,
@@ -1057,11 +1059,10 @@ class ModelTesterMixin:
             " from `_deprecated_kwargs = [<deprecated_argument>]`"
         )
 
-    @parameterized.expand([True, False])
+    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
     @torch.no_grad()
     @unittest.skipIf(not is_peft_available(), "Only with PEFT")
-    def test_lora_save_load_adapter(self, use_dora=False):
-        import safetensors
+    def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
         from peft import LoraConfig
         from peft.utils import get_peft_model_state_dict
@@ -1077,8 +1078,8 @@ class ModelTesterMixin:
         output_no_lora = model(**inputs_dict, return_dict=False)[0]
 
         denoiser_lora_config = LoraConfig(
-            r=4,
-            lora_alpha=4,
+            r=rank,
+            lora_alpha=lora_alpha,
             target_modules=["to_q", "to_k", "to_v", "to_out.0"],
             init_lora_weights=False,
             use_dora=use_dora,
@@ -1145,6 +1146,90 @@ class ModelTesterMixin:
 
             self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
 
+    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
+    @torch.no_grad()
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
+        from peft import LoraConfig
+
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        denoiser_lora_config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        metadata = model.peft_config["default"].to_dict()
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(model_file))
+
+            model.unload_lora()
+            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            parsed_metadata = model.peft_config["default_0"].to_dict()
+            self.assertTrue(check_if_dicts_are_equal(metadata, parsed_metadata))
+
+    @torch.no_grad()
+    @unittest.skipIf(not is_peft_available(), "Only with PEFT")
+    def test_lora_adapter_wrong_metadata_raises_error(self):
+        from peft import LoraConfig
+
+        from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+        from diffusers.loaders.peft import PeftAdapterMixin
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            return
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(model_file))
+
+            # Perturb the metadata in the state dict.
+            loaded_state_dict = safetensors.torch.load_file(model_file)
+            metadata = {"format": "pt"}
+            lora_adapter_metadata = denoiser_lora_config.to_dict()
+            lora_adapter_metadata.update({"foo": 1, "bar": 2})
+            for key, value in lora_adapter_metadata.items():
+                if isinstance(value, set):
+                    lora_adapter_metadata[key] = list(value)
+            metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+            safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
+
+            model.unload_lora()
+            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
+
+            with self.assertRaises(TypeError) as err_context:
+                model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
+
     @require_torch_accelerator
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
...
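Why the perturbed-metadata test above expects a `TypeError`: when metadata is present, the loader passes the parsed dict straight to `LoraConfig(**lora_config_kwargs)`, and `LoraConfig` is a dataclass that rejects unknown keyword arguments, so the injected `foo`/`bar` keys surface as the wrapped `TypeError`. A minimal sketch of that failure mode (requires `peft`; the stray `foo` kwarg is illustrative):

from peft import LoraConfig

try:
    LoraConfig(r=4, lora_alpha=4, foo=1)  # "foo" is not a LoraConfig field
except TypeError as e:
    print(e)  # unexpected keyword argument 'foo'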