Unverified Commit 368958df authored by Sayak Paul, committed by GitHub

[LoRA] parse metadata from LoRA and save metadata (#11324)
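A minimal usage sketch of the metadata round-trip this PR enables. The repo id, save directory, and the `r`/`lora_alpha` values are placeholders, not taken from this PR; only `return_lora_metadata`, `unet_lora_adapter_metadata`, and the `lora_adapter_metadata` header key come from the changes below.

```python
import json

from diffusers import StableDiffusionPipeline
from safetensors import safe_open

# Hypothetical LoRA repo id -- substitute a real SD 1.5 LoRA.
lora_id = "some-user/some-sd15-lora"

# Loading: additionally request the adapter metadata stored in the checkpoint.
state_dict, network_alphas, metadata = StableDiffusionPipeline.lora_state_dict(
    lora_id, return_lora_metadata=True
)
# `metadata`, when present, carries the serialized LoraConfig kwargs so they
# no longer have to be derived from the state dict.

# Saving: per-component metadata is JSON-serialized into the safetensors
# header under the "lora_adapter_metadata" key (r/lora_alpha here are
# illustrative values only).
unet_layers = {k.removeprefix("unet."): v for k, v in state_dict.items() if k.startswith("unet.")}
StableDiffusionPipeline.save_lora_weights(
    save_directory="my-lora",
    unet_lora_layers=unet_layers,
    unet_lora_adapter_metadata={"r": 4, "lora_alpha": 4},
)

# The stored metadata can be read straight back from the safetensors header.
with safe_open("my-lora/pytorch_lora_weights.safetensors", framework="pt") as f:
    print(json.loads(f.metadata()["lora_adapter_metadata"]))
```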



* feat: parse metadata from lora state dicts.

* tests

* fix tests

* key renaming

* fix

* smol update

* smol updates

* load metadata.

* automatically save metadata in save_lora_adapter.

* propagate changes.

* changes

* add test to models too.

* tighter tests.

* updates

* fixes

* rename tests.

* sorted.

* Update src/diffusers/loaders/lora_base.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* review suggestions.

* removeprefix.

* propagate changes.

* fix-copies

* sd

* docs.

* fixes

* get review ready.

* one more test to catch error.

* change to a different approach.

* fix-copies.

* todo

* sd3

* update

* revert changes in get_peft_kwargs.

* update

* fixes

* fixes

* simplify _load_sft_state_dict_metadata

* update

* style fix

* update

* update

* update

* empty commit

* _pack_dict_with_prefix

* update

* TODO 1.

* todo: 2.

* todo: 3.

* update

* update

* Apply suggestions from code review
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* reraise.

* move argument.

---------
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent e52ceae3
......@@ -282,10 +282,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
......
......@@ -159,10 +159,7 @@ class IPAdapterMixin:
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
......@@ -465,10 +462,7 @@ class FluxIPAdapterMixin:
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
......@@ -750,10 +744,7 @@ class SD3IPAdapterMixin:
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
......
......@@ -14,6 +14,7 @@
import copy
import inspect
import json
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Union
......@@ -45,6 +46,7 @@ from ..utils import (
set_adapter_layers,
set_weights_and_activate_adapters,
)
from ..utils.state_dict_utils import _load_sft_state_dict_metadata
if is_transformers_available():
......@@ -62,6 +64,7 @@ logger = logging.get_logger(__name__)
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
......@@ -206,6 +209,7 @@ def _fetch_state_dict(
subfolder,
user_agent,
allow_pickle,
metadata=None,
):
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
......@@ -236,11 +240,14 @@ def _fetch_state_dict(
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
metadata = _load_sft_state_dict_metadata(model_file)
except (IOError, safetensors.SafetensorError) as e:
if not allow_pickle:
raise e
# try loading non-safetensors weights
model_file = None
metadata = None
pass
if model_file is None:
......@@ -261,10 +268,11 @@ def _fetch_state_dict(
user_agent=user_agent,
)
state_dict = load_state_dict(model_file)
metadata = None
else:
state_dict = pretrained_model_name_or_path_or_dict
return state_dict
return state_dict, metadata
def _best_guess_weight_name(
......@@ -306,6 +314,11 @@ def _best_guess_weight_name(
return weight_name
def _pack_dict_with_prefix(state_dict, prefix):
sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
return sd_with_prefix
def _load_lora_into_text_encoder(
state_dict,
network_alphas,
......@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
if network_alphas and metadata:
raise ValueError("`network_alphas` and `metadata` cannot both be specified at the same time.")
peft_kwargs = {}
if low_cpu_mem_usage:
if not is_peft_version(">=", "0.13.1"):
......@@ -349,6 +366,8 @@ def _load_lora_into_text_encoder(
# Load the layers corresponding to text encoder and make necessary adjustments.
if prefix is not None:
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if metadata is not None:
metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
logger.info(f"Loading {prefix}.")
......@@ -376,7 +395,10 @@ def _load_lora_into_text_encoder(
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
if "use_dora" in lora_config_kwargs:
if lora_config_kwargs["use_dora"]:
......@@ -398,7 +420,10 @@ def _load_lora_into_text_encoder(
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e
# adapter_name
if adapter_name is None:
......@@ -889,8 +914,7 @@ class LoraBaseMixin:
@staticmethod
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
return _pack_dict_with_prefix(layers_weights, prefix)
@staticmethod
def write_lora_layers(
......@@ -900,16 +924,32 @@ class LoraBaseMixin:
weight_name: str,
save_function: Callable,
safe_serialization: bool,
lora_adapter_metadata: Optional[dict] = None,
):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if lora_adapter_metadata and not safe_serialization:
raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
if save_function is None:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
# Inject framework format.
metadata = {"format": "pt"}
if lora_adapter_metadata:
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
lora_adapter_metadata, indent=2, sort_keys=True
)
return safetensors.torch.save_file(weights, filename, metadata=metadata)
else:
save_function = torch.save
......
......@@ -37,6 +37,7 @@ from .lora_base import ( # noqa
LoraBaseMixin,
_fetch_state_dict,
_load_lora_into_text_encoder,
_pack_dict_with_prefix,
)
from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
......@@ -202,7 +203,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
kwargs["return_lora_metadata"] = True
state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
......@@ -213,6 +215,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
network_alphas=network_alphas,
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -226,6 +229,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
lora_scale=self.lora_scale,
adapter_name=adapter_name,
_pipeline=self,
metadata=metadata,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
......@@ -282,6 +286,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
The subfolder location of a model file within a larger model repository on the Hub or locally.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
......@@ -295,18 +301,16 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
weight_name = kwargs.pop("weight_name", None)
unet_config = kwargs.pop("unet_config", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -343,7 +347,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
return state_dict, network_alphas
out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
return out
@classmethod
def load_lora_into_unet(
......@@ -355,6 +360,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
......@@ -378,6 +384,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
......@@ -396,6 +405,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
prefix=cls.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -413,6 +423,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
......@@ -440,6 +451,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
......@@ -449,6 +463,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -464,6 +479,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
unet_lora_adapter_metadata=None,
text_encoder_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
......@@ -486,8 +503,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
unet_lora_adapter_metadata:
LoRA adapter metadata associated with the unet to be serialized with the state dict.
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not (unet_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
......@@ -498,6 +520,14 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
if unet_lora_adapter_metadata:
lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
......@@ -506,6 +536,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
......@@ -641,7 +672,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(
kwargs["return_lora_metadata"] = True
state_dict, network_alphas, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
......@@ -656,6 +688,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
network_alphas=network_alphas,
unet=self.unet,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -667,6 +700,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
prefix=self.text_encoder_name,
lora_scale=self.lora_scale,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -678,6 +712,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
prefix=f"{self.text_encoder_name}_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -736,6 +771,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
The subfolder location of a model file within a larger model repository on the Hub or locally.
weight_name (`str`, *optional*, defaults to None):
Name of the serialized state dict file.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# UNet and text encoder or both.
......@@ -749,18 +786,16 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
weight_name = kwargs.pop("weight_name", None)
unet_config = kwargs.pop("unet_config", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -797,7 +832,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
return state_dict, network_alphas
out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
return out
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
......@@ -810,6 +846,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
......@@ -833,6 +870,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
......@@ -851,6 +891,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
prefix=cls.unet_name,
network_alphas=network_alphas,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -869,6 +910,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
......@@ -896,6 +938,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
......@@ -905,6 +950,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -921,6 +967,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
unet_lora_adapter_metadata=None,
text_encoder_lora_adapter_metadata=None,
text_encoder_2_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
......@@ -946,8 +995,15 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
unet_lora_adapter_metadata:
LoRA adapter metadata associated with the unet to be serialized with the state dict.
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
text_encoder_2_lora_adapter_metadata:
LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
......@@ -963,6 +1019,19 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
if unet_lora_adapter_metadata is not None:
lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
if text_encoder_2_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
)
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
......@@ -970,6 +1039,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
......@@ -1103,6 +1173,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
......@@ -1116,18 +1188,16 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -1148,7 +1218,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
def load_lora_weights(
self,
......@@ -1197,7 +1268,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
......@@ -1207,6 +1279,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -1218,6 +1291,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
prefix=self.text_encoder_name,
lora_scale=self.lora_scale,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -1229,6 +1303,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
prefix=f"{self.text_encoder_name}_2",
lora_scale=self.lora_scale,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -1236,7 +1311,14 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
@classmethod
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
......@@ -1256,6 +1338,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
......@@ -1268,6 +1353,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -1286,6 +1372,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
......@@ -1313,6 +1400,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
......@@ -1322,6 +1412,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -1339,6 +1430,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata=None,
text_encoder_lora_adapter_metadata=None,
text_encoder_2_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
......@@ -1364,8 +1458,15 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
text_encoder_2_lora_adapter_metadata:
LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
......@@ -1381,6 +1482,21 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
if text_encoder_2_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
)
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
......@@ -1388,6 +1504,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
......@@ -1519,6 +1636,8 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
......@@ -1532,18 +1651,16 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -1564,7 +1681,8 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
......@@ -1609,7 +1727,8 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
......@@ -1619,6 +1738,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -1627,7 +1747,14 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
......@@ -1647,6 +1774,9 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
......@@ -1659,6 +1789,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -1674,9 +1805,10 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
......@@ -1693,14 +1825,21 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
......@@ -1710,6 +1849,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
......@@ -1843,7 +1983,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
......@@ -1856,18 +1997,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -1921,8 +2060,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
)
if return_alphas:
return state_dict, network_alphas
if return_alphas or return_lora_metadata:
outputs = [state_dict]
if return_alphas:
outputs.append(network_alphas)
if return_lora_metadata:
outputs.append(metadata)
return tuple(outputs)
else:
return state_dict
......@@ -1973,7 +2117,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(
kwargs["return_lora_metadata"] = True
state_dict, network_alphas, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
)
......@@ -2024,6 +2169,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
network_alphas=network_alphas,
transformer=transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -2043,6 +2189,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
prefix=self.text_encoder_name,
lora_scale=self.lora_scale,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -2055,6 +2202,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
network_alphas,
transformer,
adapter_name=None,
metadata=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
......@@ -2081,6 +2229,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
......@@ -2093,6 +2244,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -2165,6 +2317,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
......@@ -2192,6 +2345,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
......@@ -2201,6 +2357,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -2217,6 +2374,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata=None,
text_encoder_lora_adapter_metadata=None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
......@@ -2239,8 +2398,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not (transformer_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
......@@ -2251,6 +2415,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
if transformer_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
......@@ -2259,6 +2433,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
......@@ -2626,6 +2801,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
network_alphas,
transformer,
adapter_name=None,
metadata=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
......@@ -2652,6 +2828,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError(
......@@ -2664,6 +2843,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
state_dict,
network_alphas=network_alphas,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -2682,6 +2862,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
......@@ -2709,6 +2890,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
_load_lora_into_text_encoder(
state_dict=state_dict,
......@@ -2718,6 +2902,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
prefix=prefix,
text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -2837,6 +3022,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
......@@ -2850,18 +3037,16 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -2882,7 +3067,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
def load_lora_weights(
self,
......@@ -2926,7 +3112,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
......@@ -2936,6 +3123,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -2944,7 +3132,14 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
......@@ -2964,6 +3159,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
......@@ -2976,6 +3174,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -2991,9 +3190,10 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
......@@ -3010,14 +3210,21 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
......@@ -3027,6 +3234,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
def fuse_lora(
......@@ -3153,6 +3361,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
......@@ -3166,18 +3376,16 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -3198,7 +3406,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
......@@ -3243,7 +3452,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
......@@ -3253,6 +3463,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -3261,7 +3472,14 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
......@@ -3281,6 +3499,9 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
......@@ -3293,6 +3514,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -3308,9 +3530,10 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
......@@ -3327,14 +3550,21 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
......@@ -3344,6 +3574,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
......@@ -3471,7 +3702,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
......@@ -3484,18 +3716,16 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -3520,7 +3750,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
if is_non_diffusers_format:
state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict)
return state_dict
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
......@@ -3565,7 +3796,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
......@@ -3575,6 +3807,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -3583,7 +3816,14 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
......@@ -3603,6 +3843,9 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
......@@ -3615,6 +3858,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -3630,9 +3874,10 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
......@@ -3649,14 +3894,21 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
......@@ -3666,6 +3918,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
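A minimal sketch (not part of this diff) of saving transformer LoRA layers together with their adapter metadata; `transformer` is assumed to already carry a PEFT LoRA adapter named "default", and the output directory is a placeholder:

from peft.utils import get_peft_model_state_dict

from diffusers import LTXPipeline

LTXPipeline.save_lora_weights(
    save_directory="./ltx-video-lora",  # hypothetical output directory
    transformer_lora_layers=get_peft_model_state_dict(transformer, adapter_name="default"),
    transformer_lora_adapter_metadata=transformer.peft_config["default"].to_dict(),
)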
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
......@@ -3794,6 +4047,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
......@@ -3807,18 +4062,16 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -3839,7 +4092,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
......@@ -3884,7 +4138,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
......@@ -3894,6 +4149,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -3902,7 +4158,14 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
......@@ -3922,6 +4185,9 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
......@@ -3934,6 +4200,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -3949,9 +4216,10 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
......@@ -3968,14 +4236,21 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
......@@ -3985,6 +4260,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
......@@ -4112,7 +4388,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
......@@ -4125,18 +4402,16 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -4161,7 +4436,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
if is_original_hunyuan_video:
state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
return state_dict
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
......@@ -4206,7 +4482,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
......@@ -4216,6 +4493,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -4224,7 +4502,14 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
......@@ -4244,6 +4529,9 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
......@@ -4256,6 +4544,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -4271,9 +4560,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
......@@ -4290,14 +4580,21 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
......@@ -4307,6 +4604,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
......@@ -4434,7 +4732,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
......@@ -4447,18 +4746,16 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -4484,7 +4781,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
if non_diffusers:
state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
return state_dict
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
......@@ -4529,7 +4827,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
......@@ -4539,6 +4838,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -4547,7 +4847,14 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
......@@ -4567,6 +4874,9 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
......@@ -4579,6 +4889,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -4594,9 +4905,10 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
......@@ -4613,14 +4925,21 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
......@@ -4630,6 +4949,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
......@@ -4757,7 +5077,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
......@@ -4770,18 +5091,16 @@ class WanLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -4806,7 +5125,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
@classmethod
def _maybe_expand_t2v_lora_for_i2v(
......@@ -4898,7 +5218,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
# convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
state_dict = self._maybe_expand_t2v_lora_for_i2v(
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
......@@ -4912,6 +5233,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -4920,7 +5242,14 @@ class WanLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
......@@ -4940,6 +5269,9 @@ class WanLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
......@@ -4952,6 +5284,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -4967,9 +5300,10 @@ class WanLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
......@@ -4986,14 +5320,21 @@ class WanLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
......@@ -5003,6 +5344,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
......@@ -5131,6 +5473,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
......@@ -5144,18 +5488,16 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -5176,7 +5518,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
......@@ -5221,7 +5564,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
......@@ -5231,6 +5575,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -5239,7 +5584,14 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
......@@ -5259,6 +5611,9 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
......@@ -5271,6 +5626,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -5286,9 +5642,10 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
......@@ -5305,14 +5662,21 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
......@@ -5322,6 +5686,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
......@@ -5449,7 +5814,8 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
allowed by Git.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
......@@ -5462,18 +5828,16 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -5498,7 +5862,8 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
if is_non_diffusers_format:
state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
return state_dict
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights(
......@@ -5543,7 +5908,8 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
......@@ -5553,6 +5919,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -5561,7 +5928,14 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `transformer`.
......@@ -5581,6 +5955,9 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
weights.
hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
......@@ -5593,6 +5970,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
......@@ -5608,9 +5986,10 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Save the LoRA parameters corresponding to the transformer.
Arguments:
save_directory (`str` or `os.PathLike`):
......@@ -5627,14 +6006,21 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
"""
state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers:
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model
cls.write_lora_layers(
......@@ -5644,6 +6030,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
)
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
......
......@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import json
import os
from functools import partial
from pathlib import Path
......@@ -185,6 +186,7 @@ class PeftAdapterMixin:
Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap
metadata (`dict`, *optional*):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
"""
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer
......@@ -202,6 +204,7 @@ class PeftAdapterMixin:
network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
metadata = kwargs.pop("metadata", None)
allow_pickle = False
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
......@@ -209,12 +212,9 @@ class PeftAdapterMixin:
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict = _fetch_state_dict(
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
......@@ -227,12 +227,17 @@ class PeftAdapterMixin:
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
metadata=metadata,
)
if network_alphas is not None and prefix is None:
raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
if network_alphas and metadata:
raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")
if prefix is not None:
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if metadata is not None:
metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
......@@ -267,7 +272,12 @@ class PeftAdapterMixin:
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
}
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict)
if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(
rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
)
_maybe_raise_error_for_ambiguity(lora_config_kwargs)
if "use_dora" in lora_config_kwargs:
......@@ -290,7 +300,11 @@ class PeftAdapterMixin:
if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs)
try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e
# adapter_name
if adapter_name is None:
adapter_name = get_adapter_name(self)
......@@ -445,17 +459,13 @@ class PeftAdapterMixin:
underlying model has multiple adapters loaded.
upcast_before_saving (`bool`, defaults to `False`):
Whether to cast the underlying model to `torch.float32` before serialization.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
weight_name (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
"""
from peft.utils import get_peft_model_state_dict
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
if adapter_name is None:
adapter_name = get_adapter_name(self)
......@@ -463,6 +473,8 @@ class PeftAdapterMixin:
if adapter_name not in getattr(self, "peft_config", {}):
raise ValueError(f"Adapter name {adapter_name} not found in the model.")
lora_adapter_metadata = self.peft_config[adapter_name].to_dict()
lora_layers_to_save = get_peft_model_state_dict(
self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
)
......@@ -472,7 +484,15 @@ class PeftAdapterMixin:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
# Inject framework format.
metadata = {"format": "pt"}
if lora_adapter_metadata is not None:
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
return safetensors.torch.save_file(weights, filename, metadata=metadata)
else:
save_function = torch.save
......@@ -485,7 +505,6 @@ class PeftAdapterMixin:
else:
weight_name = LORA_WEIGHT_NAME
# TODO: we could consider saving the `peft_config` as well.
save_path = Path(save_directory, weight_name).as_posix()
save_function(lora_layers_to_save, save_path)
logger.info(f"Model weights saved in {save_path}")
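A quick sketch (not part of this diff) of reading the serialized adapter metadata back from the safetensors header written above; the file name assumes the default `LORA_WEIGHT_NAME_SAFE`:

import json

from safetensors import safe_open

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    header = f.metadata() or {}
raw = header.get(LORA_ADAPTER_METADATA_KEY)
lora_config_kwargs = json.loads(raw) if raw else None  # mirrors `_load_sft_state_dict_metadata`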
......
......@@ -155,10 +155,7 @@ class UNet2DConditionLoadersMixin:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
......
......@@ -16,6 +16,7 @@ State dict utilities: utility methods for converting state dicts easily
"""
import enum
import json
from .import_utils import is_torch_available
from .logging import get_logger
......@@ -347,3 +348,16 @@ def state_dict_all_zero(state_dict, filter_str=None):
state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}
return all(torch.all(param == 0).item() for param in state_dict.values())
def _load_sft_state_dict_metadata(model_file: str):
import safetensors.torch
from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
return json.loads(raw) if raw else None
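For completeness, a direct call to the helper above (private API, so subject to change); it returns `None` when the safetensors header carries no `LORA_ADAPTER_METADATA_KEY` entry:

from diffusers.utils.state_dict_utils import _load_sft_state_dict_metadata

metadata = _load_sft_state_dict_metadata("pytorch_lora_weights.safetensors")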
......@@ -133,6 +133,29 @@ def numpy_cosine_similarity_distance(a, b):
return distance
def check_if_dicts_are_equal(dict1, dict2):
dict1, dict2 = dict1.copy(), dict2.copy()
for key, value in dict1.items():
if isinstance(value, set):
dict1[key] = sorted(value)
for key, value in dict2.items():
if isinstance(value, set):
dict2[key] = sorted(value)
for key in dict1:
if key not in dict2:
return False
if dict1[key] != dict2[key]:
return False
for key in dict2:
if key not in dict1:
return False
return True
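A tiny illustration (not part of this diff): sets are normalized to sorted lists on both sides before comparison, so differing container types for, e.g., `target_modules` still compare equal.

assert check_if_dicts_are_equal(
    {"r": 4, "target_modules": {"to_q", "to_k"}},
    {"r": 4, "target_modules": ["to_k", "to_q"]},
)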
def print_tensor_test(
tensor,
limit_to_slices=None,
......
......@@ -24,11 +24,7 @@ from diffusers import (
WanPipeline,
WanTransformer3DModel,
)
from diffusers.utils.testing_utils import (
floats_tensor,
require_peft_backend,
skip_mps,
)
from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
sys.path.append(".")
......
......@@ -22,6 +22,7 @@ from itertools import product
import numpy as np
import pytest
import torch
from parameterized import parameterized
from diffusers import (
AutoencoderKL,
......@@ -33,6 +34,7 @@ from diffusers.utils import logging
from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import (
CaptureLogger,
check_if_dicts_are_equal,
floats_tensor,
is_torch_version,
require_peft_backend,
......@@ -71,6 +73,13 @@ def check_if_lora_correctly_set(model) -> bool:
return False
def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str):
extracted = {
k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.")
}
check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"])
def initialize_dummy_state_dict(state_dict):
if not all(v.device.type == "meta" for _, v in state_dict.items()):
raise ValueError("`state_dict` has non-meta values.")
......@@ -118,7 +127,7 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
def get_dummy_components(self, scheduler_cls=None, use_dora=False):
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders:
......@@ -126,6 +135,7 @@ class PeftLoraLoaderMixinTests:
scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
rank = 4
lora_alpha = rank if lora_alpha is None else lora_alpha
torch.manual_seed(0)
if self.unet_kwargs is not None:
......@@ -161,7 +171,7 @@ class PeftLoraLoaderMixinTests:
text_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
lora_alpha=lora_alpha,
target_modules=self.text_encoder_target_modules,
init_lora_weights=False,
use_dora=use_dora,
......@@ -169,7 +179,7 @@ class PeftLoraLoaderMixinTests:
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=rank,
lora_alpha=lora_alpha,
target_modules=self.denoiser_target_modules,
init_lora_weights=False,
use_dora=use_dora,
......@@ -246,6 +256,13 @@ class PeftLoraLoaderMixinTests:
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
return state_dicts
def _get_lora_adapter_metadata(self, modules_to_save):
metadatas = {}
for module_name, module in modules_to_save.items():
if module is not None:
metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
return metadatas
def _get_modules_to_save(self, pipe, has_denoiser=False):
modules_to_save = {}
lora_loadable_modules = self.pipeline_class._lora_loadable_modules
......@@ -2214,6 +2231,86 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe(**inputs, generator=torch.manual_seed(0))[0]
@parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, lora_alpha=lora_alpha
)
pipe = self.pipeline_class(**components)
pipe, _ = self.check_if_adapters_added_correctly(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
pipe.unload_lora_weights()
out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True)
if len(out) == 3:
_, _, parsed_metadata = out
elif len(out) == 2:
_, parsed_metadata = out
denoiser_key = (
f"{self.pipeline_class.transformer_name}"
if self.transformer_kwargs is not None
else f"{self.pipeline_class.unet_name}"
)
self.assertTrue(any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata))
check_module_lora_metadata(
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key
)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
text_encoder_key = self.pipeline_class.text_encoder_name
self.assertTrue(any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata))
check_module_lora_metadata(
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key
)
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
text_encoder_2_key = "text_encoder_2"
self.assertTrue(any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata))
check_module_lora_metadata(
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key
)
@parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, lora_alpha=lora_alpha
)
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.check_if_adapters_added_correctly(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
pipe.unload_lora_weights()
pipe.load_lora_weights(tmpdir)
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
)
def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes:
......
......@@ -30,6 +30,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np
import requests_mock
import safetensors.torch
import torch
import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
......@@ -62,6 +63,7 @@ from diffusers.utils.testing_utils import (
backend_max_memory_allocated,
backend_reset_peak_memory_stats,
backend_synchronize,
check_if_dicts_are_equal,
get_python_version,
is_torch_compile,
numpy_cosine_similarity_distance,
......@@ -1057,11 +1059,10 @@ class ModelTesterMixin:
" from `_deprecated_kwargs = [<deprecated_argument>]`"
)
@parameterized.expand([True, False])
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_save_load_adapter(self, use_dora=False):
import safetensors
def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
......@@ -1077,8 +1078,8 @@ class ModelTesterMixin:
output_no_lora = model(**inputs_dict, return_dict=False)[0]
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
......@@ -1145,6 +1146,90 @@ class ModelTesterMixin:
self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
from peft import LoraConfig
from diffusers.loaders.peft import PeftAdapterMixin
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
return
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
metadata = model.peft_config["default"].to_dict()
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(model_file))
model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
check_if_dicts_are_equal(metadata, parsed_metadata)
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_adapter_wrong_metadata_raises_error(self):
from peft import LoraConfig
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
from diffusers.loaders.peft import PeftAdapterMixin
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
return
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(model_file))
# Perturb the metadata in the state dict.
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with self.assertRaises(TypeError) as err_context:
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
@require_torch_accelerator
def test_cpu_offload(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......