Unverified commit 368958df, authored by Sayak Paul and committed by GitHub

[LoRA] parse metadata from LoRA and save metadata (#11324)



* feat: parse metadata from lora state dicts.

* tests

* fix tests

* key renaming

* fix

* smol update

* smol updates

* load metadata.

* automatically save metadata in save_lora_adapter.

* propagate changes.

* changes

* add test to models too.

* tighter tests.

* updates

* fixes

* rename tests.

* sorted.

* Update src/diffusers/loaders/lora_base.py
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* review suggestions.

* removeprefix.

* propagate changes.

* fix-copies

* sd

* docs.

* fixes

* get review ready.

* one more test to catch error.

* change to a different approach.

* fix-copies.

* todo

* sd3

* update

* revert changes in get_peft_kwargs.

* update

* fixes

* fixes

* simplify _load_sft_state_dict_metadata

* update

* style fix

* update

* update

* update

* empty commit

* _pack_dict_with_prefix

* update

* TODO 1.

* todo: 2.

* todo: 3.

* update

* update

* Apply suggestions from code review
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>

* reraise.

* move argument.

---------
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
parent e52ceae3
@@ -282,10 +282,7 @@ class IPAdapterFaceIDStableDiffusionPipeline(
         revision = kwargs.pop("revision", None)
         subfolder = kwargs.pop("subfolder", None)

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         model_file = _get_model_file(
             pretrained_model_name_or_path_or_dict,
             weights_name=weight_name,
@@ -159,10 +159,7 @@ class IPAdapterMixin:
                 " `low_cpu_mem_usage=False`."
             )

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         state_dicts = []
         for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
             pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -465,10 +462,7 @@ class FluxIPAdapterMixin:
                 " `low_cpu_mem_usage=False`."
             )

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         state_dicts = []
         for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
             pretrained_model_name_or_path_or_dict, weight_name, subfolder
@@ -750,10 +744,7 @@ class SD3IPAdapterMixin:
                 " `low_cpu_mem_usage=False`."
            )

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
             model_file = _get_model_file(
@@ -14,6 +14,7 @@
 import copy
 import inspect
+import json
 import os
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Union
@@ -45,6 +46,7 @@ from ..utils import (
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from ..utils.state_dict_utils import _load_sft_state_dict_metadata

 if is_transformers_available():
@@ -62,6 +64,7 @@ logger = logging.get_logger(__name__)
 LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
 LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+LORA_ADAPTER_METADATA_KEY = "lora_adapter_metadata"


 def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
@@ -206,6 +209,7 @@ def _fetch_state_dict(
     subfolder,
     user_agent,
     allow_pickle,
+    metadata=None,
 ):
     model_file = None
     if not isinstance(pretrained_model_name_or_path_or_dict, dict):
@@ -236,11 +240,14 @@ def _fetch_state_dict(
                     user_agent=user_agent,
                 )
                 state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                metadata = _load_sft_state_dict_metadata(model_file)
             except (IOError, safetensors.SafetensorError) as e:
                 if not allow_pickle:
                     raise e
                 # try loading non-safetensors weights
                 model_file = None
+                metadata = None
                 pass

     if model_file is None:
@@ -261,10 +268,11 @@ def _fetch_state_dict(
                 user_agent=user_agent,
             )
             state_dict = load_state_dict(model_file)
+            metadata = None
     else:
         state_dict = pretrained_model_name_or_path_or_dict

-    return state_dict
+    return state_dict, metadata


 def _best_guess_weight_name(
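The helper imported above, `_load_sft_state_dict_metadata`, is what turns the safetensors header into the `metadata` dict returned here. As a rough sketch only (the real helper lives in `src/diffusers/utils/state_dict_utils.py` and may differ in details), reading the header could look like this, assuming the adapter config is stored as JSON under the `lora_adapter_metadata` key:

import json
from typing import Optional

import safetensors


def load_lora_metadata_sketch(model_file: str) -> Optional[dict]:
    # The safetensors header may carry arbitrary string-to-string metadata.
    with safetensors.safe_open(model_file, framework="pt", device="cpu") as f:
        header = f.metadata() or {}
    # Only the LoRA adapter entry matters here; other keys (e.g. {"format": "pt"}) are ignored.
    serialized = header.get("lora_adapter_metadata")
    return json.loads(serialized) if serialized is not None else None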
@@ -306,6 +314,11 @@ def _best_guess_weight_name(
     return weight_name


+def _pack_dict_with_prefix(state_dict, prefix):
+    sd_with_prefix = {f"{prefix}.{key}": value for key, value in state_dict.items()}
+    return sd_with_prefix
+
+
 def _load_lora_into_text_encoder(
     state_dict,
     network_alphas,
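For illustration, `_pack_dict_with_prefix` is a tiny namespacing utility that both the packed weights and, later, the packed adapter metadata reuse:

_pack_dict_with_prefix({"r": 4, "lora_alpha": 8}, "unet")
# -> {"unet.r": 4, "unet.lora_alpha": 8}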
@@ -317,10 +330,14 @@ def _load_lora_into_text_encoder(
     _pipeline=None,
     low_cpu_mem_usage=False,
     hotswap: bool = False,
+    metadata=None,
 ):
     if not USE_PEFT_BACKEND:
         raise ValueError("PEFT backend is required for this method.")

+    if network_alphas and metadata:
+        raise ValueError("`network_alphas` and `metadata` cannot be specified both at the same time.")
+
     peft_kwargs = {}
     if low_cpu_mem_usage:
         if not is_peft_version(">=", "0.13.1"):
@@ -349,6 +366,8 @@ def _load_lora_into_text_encoder(
     # Load the layers corresponding to text encoder and make necessary adjustments.
     if prefix is not None:
         state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
+        if metadata is not None:
+            metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}

     if len(state_dict) > 0:
         logger.info(f"Loading {prefix}.")
@@ -376,7 +395,10 @@ def _load_lora_into_text_encoder(
             alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
             network_alphas = {k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys}

-        lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
+        if metadata is not None:
+            lora_config_kwargs = metadata
+        else:
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)

         if "use_dora" in lora_config_kwargs:
             if lora_config_kwargs["use_dora"]:
@@ -398,7 +420,10 @@ def _load_lora_into_text_encoder(
             if is_peft_version("<=", "0.13.2"):
                 lora_config_kwargs.pop("lora_bias")

-        lora_config = LoraConfig(**lora_config_kwargs)
+        try:
+            lora_config = LoraConfig(**lora_config_kwargs)
+        except TypeError as e:
+            raise TypeError("`LoraConfig` class could not be instantiated.") from e

         # adapter_name
         if adapter_name is None:
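Taken together, the text-encoder path now prefers the stored adapter config over heuristics: when `metadata` is present it is passed verbatim to `LoraConfig`, otherwise the kwargs are still reverse-engineered from the state dict. A condensed, hedged illustration of that decision, assuming `get_peft_kwargs` is importable from `diffusers.utils` as it is inside `lora_base.py`:

from diffusers.utils import get_peft_kwargs
from peft import LoraConfig


def build_lora_config(rank, network_alphas, state_dict, metadata=None):
    # Stored metadata wins; otherwise fall back to deriving the kwargs from the weights.
    kwargs = metadata if metadata is not None else get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
    try:
        return LoraConfig(**kwargs)
    except TypeError as e:
        # Surfaces a clearer error when stored metadata doesn't match the installed peft version.
        raise TypeError("`LoraConfig` class could not be instantiated.") from e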
@@ -889,8 +914,7 @@ class LoraBaseMixin:
     @staticmethod
     def pack_weights(layers, prefix):
         layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
-        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
-        return layers_state_dict
+        return _pack_dict_with_prefix(layers_weights, prefix)

     @staticmethod
     def write_lora_layers(
@@ -900,16 +924,32 @@ class LoraBaseMixin:
         weight_name: str,
         save_function: Callable,
         safe_serialization: bool,
+        lora_adapter_metadata: Optional[dict] = None,
     ):
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

+        if lora_adapter_metadata and not safe_serialization:
+            raise ValueError("`lora_adapter_metadata` cannot be specified when not using `safe_serialization`.")
+        if lora_adapter_metadata and not isinstance(lora_adapter_metadata, dict):
+            raise TypeError("`lora_adapter_metadata` must be of type `dict`.")
+
         if save_function is None:
             if safe_serialization:

                 def save_function(weights, filename):
-                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
+                    # Inject framework format.
+                    metadata = {"format": "pt"}
+                    if lora_adapter_metadata:
+                        for key, value in lora_adapter_metadata.items():
+                            if isinstance(value, set):
+                                lora_adapter_metadata[key] = list(value)
+                        metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(
+                            lora_adapter_metadata, indent=2, sort_keys=True
+                        )
+
+                    return safetensors.torch.save_file(weights, filename, metadata=metadata)

             else:
                 save_function = torch.save
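With `write_lora_layers` updated, the adapter config travels inside the safetensors header instead of a sidecar file. A hedged round-trip sketch (file name, keys, and values here are illustrative, not taken from the PR's tests):

import json

import safetensors
import safetensors.torch
import torch

weights = {"unet.foo.lora_A.weight": torch.zeros(4, 16)}
adapter_metadata = {"unet.r": 4, "unet.lora_alpha": 4, "unet.target_modules": ["to_q", "to_k"]}

# Saving: the adapter config is JSON-encoded under the "lora_adapter_metadata" header key.
header = {"format": "pt", "lora_adapter_metadata": json.dumps(adapter_metadata, indent=2, sort_keys=True)}
safetensors.torch.save_file(weights, "pytorch_lora_weights.safetensors", metadata=header)

# Reading it back only needs the header, not the tensors.
with safetensors.safe_open("pytorch_lora_weights.safetensors", framework="pt") as f:
    restored = json.loads(f.metadata()["lora_adapter_metadata"])
assert restored["unet.r"] == 4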
@@ -37,6 +37,7 @@ from .lora_base import (  # noqa
     LoraBaseMixin,
     _fetch_state_dict,
     _load_lora_into_text_encoder,
+    _pack_dict_with_prefix,
 )
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
@@ -202,7 +203,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, network_alphas, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -213,6 +215,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             network_alphas=network_alphas,
             unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -226,6 +229,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             lora_scale=self.lora_scale,
             adapter_name=adapter_name,
             _pipeline=self,
+            metadata=metadata,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
         )
@@ -282,6 +286,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -295,18 +301,16 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         weight_name = kwargs.pop("weight_name", None)
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -343,7 +347,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

-        return state_dict, network_alphas
+        out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
+        return out

     @classmethod
     def load_lora_into_unet(
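For callers of the public API, the new flag is opt-in and the old two-tuple return is preserved. A usage sketch (the repository id is hypothetical):

from diffusers import StableDiffusionPipeline

state_dict, network_alphas, metadata = StableDiffusionPipeline.lora_state_dict(
    "some-user/some-sd-lora",  # hypothetical repo
    weight_name="pytorch_lora_weights.safetensors",
    return_lora_metadata=True,
)
# metadata is None for checkpoints saved without the new header entry.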
@@ -355,6 +360,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -378,6 +384,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 weights.
             hotswap (`bool`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -396,6 +405,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             prefix=cls.unet_name,
             network_alphas=network_alphas,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -413,6 +423,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -440,6 +451,9 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 weights.
             hotswap (`bool`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         _load_lora_into_text_encoder(
             state_dict=state_dict,
@@ -449,6 +463,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             prefix=prefix,
             text_encoder_name=cls.text_encoder_name,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -464,6 +479,8 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        unet_lora_adapter_metadata=None,
+        text_encoder_lora_adapter_metadata=None,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -486,8 +503,13 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            unet_lora_adapter_metadata:
+                LoRA adapter metadata associated with the unet to be serialized with the state dict.
+            text_encoder_lora_adapter_metadata:
+                LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not (unet_lora_layers or text_encoder_lora_layers):
             raise ValueError("You must pass at least one of `unet_lora_layers` and `text_encoder_lora_layers`.")
@@ -498,6 +520,14 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
         if text_encoder_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))

+        if unet_lora_adapter_metadata:
+            lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
+            )
+
         # Save the model
         cls.write_lora_layers(
             state_dict=state_dict,
@@ -506,6 +536,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(
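On the saving side, the adapter config is passed alongside the layers and ends up prefixed per component ("unet.", "text_encoder.") before being written to the header. A hedged sketch with stand-in values (in practice both dicts would come from a trained PEFT adapter):

import torch

from diffusers import StableDiffusionPipeline

# Stand-ins: a single fake LoRA weight and a plausible LoraConfig-kwargs dict.
unet_lora_state_dict = {"down_blocks.0.attentions.0.proj_in.lora_A.weight": torch.zeros(4, 320)}
unet_metadata = {"r": 4, "lora_alpha": 4, "target_modules": ["to_q", "to_k", "to_v"]}

StableDiffusionPipeline.save_lora_weights(
    save_directory="my-sd-lora",
    unet_lora_layers=unet_lora_state_dict,
    unet_lora_adapter_metadata=unet_metadata,
)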
@@ -641,7 +672,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict, network_alphas = self.lora_state_dict(
+        kwargs["return_lora_metadata"] = True
+        state_dict, network_alphas, metadata = self.lora_state_dict(
             pretrained_model_name_or_path_or_dict,
             unet_config=self.unet.config,
             **kwargs,
@@ -656,6 +688,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             network_alphas=network_alphas,
             unet=self.unet,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -667,6 +700,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             prefix=self.text_encoder_name,
             lora_scale=self.lora_scale,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -678,6 +712,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             prefix=f"{self.text_encoder_name}_2",
             lora_scale=self.lora_scale,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -736,6 +771,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             weight_name (`str`, *optional*, defaults to None):
                 Name of the serialized state dict file.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
         # UNet and text encoder or both.
@@ -749,18 +786,16 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         weight_name = kwargs.pop("weight_name", None)
         unet_config = kwargs.pop("unet_config", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -797,7 +832,8 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
             state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)

-        return state_dict, network_alphas
+        out = (state_dict, network_alphas, metadata) if return_lora_metadata else (state_dict, network_alphas)
+        return out

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionLoraLoaderMixin.load_lora_into_unet
@@ -810,6 +846,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -833,6 +870,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 weights.
             hotswap (`bool`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         if not USE_PEFT_BACKEND:
             raise ValueError("PEFT backend is required for this method.")
@@ -851,6 +891,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             prefix=cls.unet_name,
             network_alphas=network_alphas,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -869,6 +910,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -896,6 +938,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 weights.
             hotswap (`bool`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         _load_lora_into_text_encoder(
             state_dict=state_dict,
@@ -905,6 +950,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             prefix=prefix,
             text_encoder_name=cls.text_encoder_name,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -921,6 +967,9 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        unet_lora_adapter_metadata=None,
+        text_encoder_lora_adapter_metadata=None,
+        text_encoder_2_lora_adapter_metadata=None,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -946,8 +995,15 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            unet_lora_adapter_metadata:
+                LoRA adapter metadata associated with the unet to be serialized with the state dict.
+            text_encoder_lora_adapter_metadata:
+                LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
+            text_encoder_2_lora_adapter_metadata:
+                LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
@@ -963,6 +1019,19 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

+        if unet_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(_pack_dict_with_prefix(unet_lora_adapter_metadata, cls.unet_name))
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
+            )
+
+        if text_encoder_2_lora_adapter_metadata:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
+            )
+
         cls.write_lora_layers(
             state_dict=state_dict,
             save_directory=save_directory,
@@ -970,6 +1039,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     def fuse_lora(
@@ -1103,6 +1173,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
@@ -1116,18 +1188,16 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -1148,7 +1218,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-        return state_dict
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out

     def load_lora_weights(
         self,
@@ -1197,7 +1268,8 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -1207,6 +1279,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -1218,6 +1291,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             prefix=self.text_encoder_name,
             lora_scale=self.lora_scale,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -1229,6 +1303,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             prefix=f"{self.text_encoder_name}_2",
             lora_scale=self.lora_scale,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -1236,7 +1311,14 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
     @classmethod
     def load_lora_into_transformer(
-        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1256,6 +1338,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 weights.
             hotswap (`bool`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
@@ -1268,6 +1353,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             state_dict,
             network_alphas=None,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -1286,6 +1372,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         _pipeline=None,
         low_cpu_mem_usage=False,
         hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1313,6 +1400,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 weights.
             hotswap (`bool`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         _load_lora_into_text_encoder(
             state_dict=state_dict,
@@ -1322,6 +1412,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             prefix=prefix,
             text_encoder_name=cls.text_encoder_name,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -1339,6 +1430,9 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata=None,
+        text_encoder_lora_adapter_metadata=None,
+        text_encoder_2_lora_adapter_metadata=None,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -1364,8 +1458,15 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+            text_encoder_lora_adapter_metadata:
+                LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
+            text_encoder_2_lora_adapter_metadata:
+                LoRA adapter metadata associated with the second text encoder to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not (transformer_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
             raise ValueError(
@@ -1381,6 +1482,21 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
         if text_encoder_2_lora_layers:
             state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+            )
+
+        if text_encoder_lora_adapter_metadata:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
+            )
+
+        if text_encoder_2_lora_adapter_metadata:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(text_encoder_2_lora_adapter_metadata, "text_encoder_2")
+            )
+
         cls.write_lora_layers(
             state_dict=state_dict,
             save_directory=save_directory,
@@ -1388,6 +1504,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
@@ -1519,6 +1636,8 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
         """
         # Load the main state dict first which has the LoRA layers for either of
@@ -1532,18 +1651,16 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
             allow_pickle = True

-        user_agent = {
-            "file_type": "attn_procs_weights",
-            "framework": "pytorch",
-        }
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

-        state_dict = _fetch_state_dict(
+        state_dict, metadata = _fetch_state_dict(
             pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
             weight_name=weight_name,
             use_safetensors=use_safetensors,
@@ -1564,7 +1681,8 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-        return state_dict
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out

     # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
@@ -1609,7 +1727,8 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
             pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
-        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
@@ -1619,6 +1738,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
             state_dict,
             transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=self,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -1627,7 +1747,14 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->AuraFlowTransformer2DModel
     def load_lora_into_transformer(
-        cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
     ):
         """
         This will load the LoRA layers specified in `state_dict` into `transformer`.
@@ -1647,6 +1774,9 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
                 weights.
             hotswap (`bool`, *optional*):
                 See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
         """
         if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
             raise ValueError(
@@ -1659,6 +1789,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
             state_dict,
             network_alphas=None,
             adapter_name=adapter_name,
+            metadata=metadata,
             _pipeline=_pipeline,
             low_cpu_mem_usage=low_cpu_mem_usage,
             hotswap=hotswap,
@@ -1674,9 +1805,10 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
         weight_name: str = None,
         save_function: Callable = None,
         safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
     ):
         r"""
-        Save the LoRA parameters corresponding to the UNet and text encoder.
+        Save the LoRA parameters corresponding to the transformer.

         Arguments:
             save_directory (`str` or `os.PathLike`):
@@ -1693,14 +1825,21 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer to be serialized with the state dict.
         """
         state_dict = {}
+        lora_adapter_metadata = {}

         if not transformer_lora_layers:
             raise ValueError("You must pass `transformer_lora_layers`.")

-        if transformer_lora_layers:
-            state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+            )

         # Save the model
         cls.write_lora_layers(
@@ -1710,6 +1849,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
             weight_name=weight_name,
             save_function=save_function,
             safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
         )

     # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
...@@ -1843,7 +1983,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -1843,7 +1983,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
allowed by Git. allowed by Git.
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally. The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
""" """
# Load the main state dict first which has the LoRA layers for either of # Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both. # transformer and text encoder or both.
...@@ -1856,18 +1997,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -1856,18 +1997,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = True use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict( state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name, weight_name=weight_name,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
...@@ -1921,8 +2060,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -1921,8 +2060,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue." f"The alpha key ({k}) seems to be incorrect. If you think this error is unexpected, please open as issue."
) )
if return_alphas: if return_alphas or return_lora_metadata:
return state_dict, network_alphas outputs = [state_dict]
if return_alphas:
outputs.append(network_alphas)
if return_lora_metadata:
outputs.append(metadata)
return tuple(outputs)
else: else:
return state_dict return state_dict
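Usage sketch for the new flag (the repo id is a placeholder): requesting both alphas and metadata now yields a 3-tuple, with metadata possibly None when the checkpoint carries none.

from diffusers import FluxPipeline

state_dict, network_alphas, metadata = FluxPipeline.lora_state_dict(
    "some-user/some-flux-lora",  # hypothetical repo id or local path
    return_alphas=True,
    return_lora_metadata=True,
)
print(type(state_dict), network_alphas, metadata)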
...@@ -1973,7 +2117,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -1973,7 +2117,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict( kwargs["return_lora_metadata"] = True
state_dict, network_alphas, metadata = self.lora_state_dict(
pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
) )
...@@ -2024,6 +2169,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2024,6 +2169,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
network_alphas=network_alphas, network_alphas=network_alphas,
transformer=transformer, transformer=transformer,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=self, _pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -2043,6 +2189,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2043,6 +2189,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
prefix=self.text_encoder_name, prefix=self.text_encoder_name,
lora_scale=self.lora_scale, lora_scale=self.lora_scale,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=self, _pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -2055,6 +2202,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2055,6 +2202,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
network_alphas, network_alphas,
transformer, transformer,
adapter_name=None, adapter_name=None,
metadata=None,
_pipeline=None, _pipeline=None,
low_cpu_mem_usage=False, low_cpu_mem_usage=False,
hotswap: bool = False, hotswap: bool = False,
...@@ -2081,6 +2229,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2081,6 +2229,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError( raise ValueError(
...@@ -2093,6 +2244,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2093,6 +2244,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
network_alphas=network_alphas, network_alphas=network_alphas,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -2165,6 +2317,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2165,6 +2317,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
_pipeline=None, _pipeline=None,
low_cpu_mem_usage=False, low_cpu_mem_usage=False,
hotswap: bool = False, hotswap: bool = False,
metadata=None,
): ):
""" """
This will load the LoRA layers specified in `state_dict` into `text_encoder` This will load the LoRA layers specified in `state_dict` into `text_encoder`
...@@ -2192,6 +2345,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2192,6 +2345,9 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
_load_lora_into_text_encoder( _load_lora_into_text_encoder(
state_dict=state_dict, state_dict=state_dict,
...@@ -2201,6 +2357,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2201,6 +2357,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
prefix=prefix, prefix=prefix,
text_encoder_name=cls.text_encoder_name, text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -2217,6 +2374,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2217,6 +2374,8 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = True, safe_serialization: bool = True,
transformer_lora_adapter_metadata=None,
text_encoder_lora_adapter_metadata=None,
): ):
r""" r"""
Save the LoRA parameters corresponding to the UNet and text encoder. Save the LoRA parameters corresponding to the UNet and text encoder.
...@@ -2239,8 +2398,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2239,8 +2398,13 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
text_encoder_lora_adapter_metadata:
LoRA adapter metadata associated with the text encoder to be serialized with the state dict.
""" """
state_dict = {} state_dict = {}
lora_adapter_metadata = {}
if not (transformer_lora_layers or text_encoder_lora_layers): if not (transformer_lora_layers or text_encoder_lora_layers):
raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.") raise ValueError("You must pass at least one of `transformer_lora_layers` and `text_encoder_lora_layers`.")
...@@ -2251,6 +2415,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2251,6 +2415,16 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
if text_encoder_lora_layers: if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name)) state_dict.update(cls.pack_weights(text_encoder_lora_layers, cls.text_encoder_name))
if transformer_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
if text_encoder_lora_adapter_metadata:
lora_adapter_metadata.update(
_pack_dict_with_prefix(text_encoder_lora_adapter_metadata, cls.text_encoder_name)
)
# Save the model # Save the model
cls.write_lora_layers( cls.write_lora_layers(
state_dict=state_dict, state_dict=state_dict,
...@@ -2259,6 +2433,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin): ...@@ -2259,6 +2433,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
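A hedged save sketch (paths, tensor shapes, and key names are placeholders; real LoRA layer dicts would come from peft, e.g. via `get_peft_model_state_dict`):

import torch
from diffusers import FluxPipeline

FluxPipeline.save_lora_weights(
    save_directory="./flux-lora-with-metadata",  # placeholder directory
    transformer_lora_layers={
        # stand-in tensor; a trained adapter would supply the real weights
        "single_transformer_blocks.0.attn.to_q.lora_A.weight": torch.zeros(16, 3072),
    },
    transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},  # illustrative LoraConfig-style keys
)

The text encoder analog works the same way through `text_encoder_lora_layers` and `text_encoder_lora_adapter_metadata`.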
def fuse_lora( def fuse_lora(
...@@ -2626,6 +2801,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): ...@@ -2626,6 +2801,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
network_alphas, network_alphas,
transformer, transformer,
adapter_name=None, adapter_name=None,
metadata=None,
_pipeline=None, _pipeline=None,
low_cpu_mem_usage=False, low_cpu_mem_usage=False,
hotswap: bool = False, hotswap: bool = False,
...@@ -2652,6 +2828,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): ...@@ -2652,6 +2828,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"): if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
raise ValueError( raise ValueError(
...@@ -2664,6 +2843,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): ...@@ -2664,6 +2843,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
state_dict, state_dict,
network_alphas=network_alphas, network_alphas=network_alphas,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -2682,6 +2862,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): ...@@ -2682,6 +2862,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
_pipeline=None, _pipeline=None,
low_cpu_mem_usage=False, low_cpu_mem_usage=False,
hotswap: bool = False, hotswap: bool = False,
metadata=None,
): ):
""" """
This will load the LoRA layers specified in `state_dict` into `text_encoder` This will load the LoRA layers specified in `state_dict` into `text_encoder`
...@@ -2709,6 +2890,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): ...@@ -2709,6 +2890,9 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
_load_lora_into_text_encoder( _load_lora_into_text_encoder(
state_dict=state_dict, state_dict=state_dict,
...@@ -2718,6 +2902,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin): ...@@ -2718,6 +2902,7 @@ class AmusedLoraLoaderMixin(StableDiffusionLoraLoaderMixin):
prefix=prefix, prefix=prefix,
text_encoder_name=cls.text_encoder_name, text_encoder_name=cls.text_encoder_name,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -2837,6 +3022,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -2837,6 +3022,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
allowed by Git. allowed by Git.
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally. The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
""" """
# Load the main state dict first which has the LoRA layers for either of # Load the main state dict first which has the LoRA layers for either of
...@@ -2850,18 +3037,16 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -2850,18 +3037,16 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = True use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict( state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name, weight_name=weight_name,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
...@@ -2882,7 +3067,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -2882,7 +3067,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg) logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
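Usage sketch (placeholder repo id): unlike the Flux variant there are no network alphas here, so the opt-in return is a 2-tuple.

from diffusers import CogVideoXPipeline

state_dict, metadata = CogVideoXPipeline.lora_state_dict(
    "some-user/some-cogvideox-lora",  # hypothetical
    return_lora_metadata=True,
)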
def load_lora_weights( def load_lora_weights(
self, self,
...@@ -2926,7 +3112,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -2926,7 +3112,8 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys()) is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format: if not is_correct_format:
...@@ -2936,6 +3123,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -2936,6 +3123,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=self, _pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -2944,7 +3132,14 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -2944,7 +3132,14 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogVideoXTransformer3DModel
def load_lora_into_transformer( def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
): ):
""" """
This will load the LoRA layers specified in `state_dict` into `transformer`. This will load the LoRA layers specified in `state_dict` into `transformer`.
...@@ -2964,6 +3159,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -2964,6 +3159,9 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError( raise ValueError(
...@@ -2976,6 +3174,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -2976,6 +3174,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
network_alphas=None, network_alphas=None,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -2991,9 +3190,10 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -2991,9 +3190,10 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = True, safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
): ):
r""" r"""
Save the LoRA parameters corresponding to the UNet and text encoder. Save the LoRA parameters corresponding to the transformer.
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
...@@ -3010,14 +3210,21 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -3010,14 +3210,21 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers: if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model # Save the model
cls.write_lora_layers( cls.write_lora_layers(
...@@ -3027,6 +3234,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin): ...@@ -3027,6 +3234,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
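A sketch of deriving the metadata from an existing peft `LoraConfig` rather than hand-writing the dict; which fields the pipeline ultimately serializes is an assumption here:

import torch
from peft import LoraConfig
from diffusers import CogVideoXPipeline

lora_config = LoraConfig(r=32, lora_alpha=32, target_modules=["to_q", "to_k", "to_v"])

CogVideoXPipeline.save_lora_weights(
    save_directory="./cogvideox-lora",  # placeholder
    transformer_lora_layers={
        "transformer_blocks.0.attn1.to_q.lora_A.weight": torch.zeros(32, 3072),  # stand-in weight
    },
    transformer_lora_adapter_metadata={"r": lora_config.r, "lora_alpha": lora_config.lora_alpha},
)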
def fuse_lora( def fuse_lora(
...@@ -3153,6 +3361,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -3153,6 +3361,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
allowed by Git. allowed by Git.
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally. The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
""" """
# Load the main state dict first which has the LoRA layers for either of # Load the main state dict first which has the LoRA layers for either of
...@@ -3166,18 +3376,16 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -3166,18 +3376,16 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = True use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict( state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name, weight_name=weight_name,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
...@@ -3198,7 +3406,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -3198,7 +3406,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg) logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights( def load_lora_weights(
...@@ -3243,7 +3452,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -3243,7 +3452,8 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys()) is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format: if not is_correct_format:
...@@ -3253,6 +3463,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -3253,6 +3463,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=self, _pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -3261,7 +3472,14 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -3261,7 +3472,14 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->MochiTransformer3DModel
def load_lora_into_transformer( def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
): ):
""" """
This will load the LoRA layers specified in `state_dict` into `transformer`. This will load the LoRA layers specified in `state_dict` into `transformer`.
...@@ -3281,6 +3499,9 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -3281,6 +3499,9 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError( raise ValueError(
...@@ -3293,6 +3514,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -3293,6 +3514,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
network_alphas=None, network_alphas=None,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -3308,9 +3530,10 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -3308,9 +3530,10 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = True, safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
): ):
r""" r"""
Save the LoRA parameters corresponding to the UNet and text encoder. Save the LoRA parameters corresponding to the transformer.
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
...@@ -3327,14 +3550,21 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -3327,14 +3550,21 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers: if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model # Save the model
cls.write_lora_layers( cls.write_lora_layers(
...@@ -3344,6 +3574,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin): ...@@ -3344,6 +3574,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -3471,7 +3702,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3471,7 +3702,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
allowed by Git. allowed by Git.
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally. The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
""" """
# Load the main state dict first which has the LoRA layers for either of # Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both. # transformer and text encoder or both.
...@@ -3484,18 +3716,16 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3484,18 +3716,16 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = True use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict( state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name, weight_name=weight_name,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
...@@ -3520,7 +3750,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3520,7 +3750,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
if is_non_diffusers_format: if is_non_diffusers_format:
state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict) state_dict = _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict)
return state_dict out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
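Defensive-handling sketch (placeholder repo id): LoRAs converted from the original LTX format presumably carry no diffusers adapter metadata, so the second element may be None.

from diffusers import LTXPipeline

state_dict, metadata = LTXPipeline.lora_state_dict(
    "some-user/some-ltx-video-lora",  # hypothetical
    return_lora_metadata=True,
)
if metadata is None:
    print("No adapter metadata; the LoraConfig will be derived from the state dict on load.")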
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights( def load_lora_weights(
...@@ -3565,7 +3796,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3565,7 +3796,8 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys()) is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format: if not is_correct_format:
...@@ -3575,6 +3807,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3575,6 +3807,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=self, _pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -3583,7 +3816,14 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3583,7 +3816,14 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->LTXVideoTransformer3DModel
def load_lora_into_transformer( def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
): ):
""" """
This will load the LoRA layers specified in `state_dict` into `transformer`. This will load the LoRA layers specified in `state_dict` into `transformer`.
...@@ -3603,6 +3843,9 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3603,6 +3843,9 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError( raise ValueError(
...@@ -3615,6 +3858,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3615,6 +3858,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
network_alphas=None, network_alphas=None,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -3630,9 +3874,10 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3630,9 +3874,10 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = True, safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
): ):
r""" r"""
Save the LoRA parameters corresponding to the UNet and text encoder. Save the LoRA parameters corresponding to the transformer.
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
...@@ -3649,14 +3894,21 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3649,14 +3894,21 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers: if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model # Save the model
cls.write_lora_layers( cls.write_lora_layers(
...@@ -3666,6 +3918,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -3666,6 +3918,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -3794,6 +4047,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3794,6 +4047,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
allowed by Git. allowed by Git.
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally. The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
""" """
# Load the main state dict first which has the LoRA layers for either of # Load the main state dict first which has the LoRA layers for either of
...@@ -3807,18 +4062,16 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3807,18 +4062,16 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = True use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict( state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name, weight_name=weight_name,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
...@@ -3839,7 +4092,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3839,7 +4092,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg) logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights( def load_lora_weights(
...@@ -3884,7 +4138,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3884,7 +4138,8 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys()) is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format: if not is_correct_format:
...@@ -3894,6 +4149,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3894,6 +4149,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=self, _pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -3902,7 +4158,14 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3902,7 +4158,14 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SanaTransformer2DModel
def load_lora_into_transformer( def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
): ):
""" """
This will load the LoRA layers specified in `state_dict` into `transformer`. This will load the LoRA layers specified in `state_dict` into `transformer`.
...@@ -3922,6 +4185,9 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3922,6 +4185,9 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError( raise ValueError(
...@@ -3934,6 +4200,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3934,6 +4200,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
network_alphas=None, network_alphas=None,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -3949,9 +4216,10 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3949,9 +4216,10 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = True, safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
): ):
r""" r"""
Save the LoRA parameters corresponding to the UNet and text encoder. Save the LoRA parameters corresponding to the transformer.
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
...@@ -3968,14 +4236,21 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3968,14 +4236,21 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers: if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model # Save the model
cls.write_lora_layers( cls.write_lora_layers(
...@@ -3985,6 +4260,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin): ...@@ -3985,6 +4260,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -4112,7 +4388,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -4112,7 +4388,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
allowed by Git. allowed by Git.
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally. The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
""" """
# Load the main state dict first which has the LoRA layers for either of # Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both. # transformer and text encoder or both.
...@@ -4125,18 +4402,16 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -4125,18 +4402,16 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = True use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict( state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name, weight_name=weight_name,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
...@@ -4161,7 +4436,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -4161,7 +4436,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
if is_original_hunyuan_video: if is_original_hunyuan_video:
state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict) state_dict = _convert_hunyuan_video_lora_to_diffusers(state_dict)
return state_dict out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights( def load_lora_weights(
...@@ -4206,7 +4482,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -4206,7 +4482,8 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys()) is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format: if not is_correct_format:
...@@ -4216,6 +4493,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -4216,6 +4493,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=self, _pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -4224,7 +4502,14 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -4224,7 +4502,14 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HunyuanVideoTransformer3DModel
def load_lora_into_transformer( def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
): ):
""" """
This will load the LoRA layers specified in `state_dict` into `transformer`. This will load the LoRA layers specified in `state_dict` into `transformer`.
...@@ -4244,6 +4529,9 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -4244,6 +4529,9 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError( raise ValueError(
...@@ -4256,6 +4544,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -4256,6 +4544,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
network_alphas=None, network_alphas=None,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -4271,9 +4560,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -4271,9 +4560,10 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = True, safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
): ):
r""" r"""
Save the LoRA parameters corresponding to the UNet and text encoder. Save the LoRA parameters corresponding to the transformer.
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
...@@ -4290,14 +4580,21 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -4290,14 +4580,21 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata:
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers: if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model # Save the model
cls.write_lora_layers( cls.write_lora_layers(
...@@ -4307,6 +4604,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin): ...@@ -4307,6 +4604,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -4434,7 +4732,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4434,7 +4732,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
allowed by Git. allowed by Git.
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally. The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to False):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
""" """
# Load the main state dict first which has the LoRA layers for either of # Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both. # transformer and text encoder or both.
...@@ -4447,18 +4746,16 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4447,18 +4746,16 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = True use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict( state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name, weight_name=weight_name,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
...@@ -4484,7 +4781,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4484,7 +4781,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
if non_diffusers: if non_diffusers:
state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict) state_dict = _convert_non_diffusers_lumina2_lora_to_diffusers(state_dict)
return state_dict out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights( def load_lora_weights(
...@@ -4529,7 +4827,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4529,7 +4827,8 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys()) is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format: if not is_correct_format:
...@@ -4539,6 +4838,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4539,6 +4838,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=self, _pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -4547,7 +4847,14 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4547,7 +4847,14 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->Lumina2Transformer2DModel
def load_lora_into_transformer( def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
): ):
""" """
This will load the LoRA layers specified in `state_dict` into `transformer`. This will load the LoRA layers specified in `state_dict` into `transformer`.
...@@ -4567,6 +4874,9 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4567,6 +4874,9 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError( raise ValueError(
...@@ -4579,6 +4889,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4579,6 +4889,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
network_alphas=None, network_alphas=None,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -4594,9 +4905,10 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4594,9 +4905,10 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = True, safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
): ):
r""" r"""
Save the LoRA parameters corresponding to the UNet and text encoder. Save the LoRA parameters corresponding to the transformer.
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
...@@ -4613,14 +4925,21 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4613,14 +4925,21 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers: if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model # Save the model
cls.write_lora_layers( cls.write_lora_layers(
...@@ -4630,6 +4949,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin): ...@@ -4630,6 +4949,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
...@@ -4757,7 +5077,8 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -4757,7 +5077,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
allowed by Git. allowed by Git.
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally. The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to `False`):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
""" """
# Load the main state dict first which has the LoRA layers for either of # Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both. # transformer and text encoder or both.
...@@ -4770,18 +5091,16 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -4770,18 +5091,16 @@ class WanLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = True use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict( state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name, weight_name=weight_name,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
...@@ -4806,7 +5125,8 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -4806,7 +5125,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg) logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
@classmethod @classmethod
def _maybe_expand_t2v_lora_for_i2v( def _maybe_expand_t2v_lora_for_i2v(
...@@ -4898,7 +5218,8 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -4898,7 +5218,8 @@ class WanLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
# convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
state_dict = self._maybe_expand_t2v_lora_for_i2v( state_dict = self._maybe_expand_t2v_lora_for_i2v(
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
...@@ -4912,6 +5233,7 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -4912,6 +5233,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=self, _pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -4920,7 +5242,14 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -4920,7 +5242,14 @@ class WanLoraLoaderMixin(LoraBaseMixin):
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
def load_lora_into_transformer( def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
): ):
""" """
This will load the LoRA layers specified in `state_dict` into `transformer`. This will load the LoRA layers specified in `state_dict` into `transformer`.
...@@ -4940,6 +5269,9 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -4940,6 +5269,9 @@ class WanLoraLoaderMixin(LoraBaseMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError( raise ValueError(
...@@ -4952,6 +5284,7 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -4952,6 +5284,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
network_alphas=None, network_alphas=None,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -4967,9 +5300,10 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -4967,9 +5300,10 @@ class WanLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = True, safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
): ):
r""" r"""
Save the LoRA parameters corresponding to the UNet and text encoder. Save the LoRA parameters corresponding to the transformer.
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
...@@ -4986,14 +5320,21 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -4986,14 +5320,21 @@ class WanLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers: if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model # Save the model
cls.write_lora_layers( cls.write_lora_layers(
...@@ -5003,6 +5344,7 @@ class WanLoraLoaderMixin(LoraBaseMixin): ...@@ -5003,6 +5344,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -5131,6 +5473,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -5131,6 +5473,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
allowed by Git. allowed by Git.
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally. The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to `False`):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
""" """
# Load the main state dict first which has the LoRA layers for either of # Load the main state dict first which has the LoRA layers for either of
...@@ -5144,18 +5488,16 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -5144,18 +5488,16 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = True use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict( state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name, weight_name=weight_name,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
...@@ -5176,7 +5518,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -5176,7 +5518,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg) logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
return state_dict out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights( def load_lora_weights(
...@@ -5221,7 +5564,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -5221,7 +5564,8 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys()) is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format: if not is_correct_format:
...@@ -5231,6 +5575,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -5231,6 +5575,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=self, _pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -5239,7 +5584,14 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -5239,7 +5584,14 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
def load_lora_into_transformer( def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
): ):
""" """
This will load the LoRA layers specified in `state_dict` into `transformer`. This will load the LoRA layers specified in `state_dict` into `transformer`.
...@@ -5259,6 +5611,9 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -5259,6 +5611,9 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError( raise ValueError(
...@@ -5271,6 +5626,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -5271,6 +5626,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
network_alphas=None, network_alphas=None,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -5286,9 +5642,10 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -5286,9 +5642,10 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = True, safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
): ):
r""" r"""
Save the LoRA parameters corresponding to the UNet and text encoder. Save the LoRA parameters corresponding to the transformer.
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
...@@ -5305,14 +5662,21 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -5305,14 +5662,21 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers: if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model # Save the model
cls.write_lora_layers( cls.write_lora_layers(
...@@ -5322,6 +5686,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin): ...@@ -5322,6 +5686,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
...@@ -5449,7 +5814,8 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5449,7 +5814,8 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
allowed by Git. allowed by Git.
subfolder (`str`, *optional*, defaults to `""`): subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally. The subfolder location of a model file within a larger model repository on the Hub or locally.
return_lora_metadata (`bool`, *optional*, defaults to `False`):
When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
""" """
# Load the main state dict first which has the LoRA layers for either of # Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both. # transformer and text encoder or both.
...@@ -5462,18 +5828,16 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5462,18 +5828,16 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
subfolder = kwargs.pop("subfolder", None) subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None) weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None) use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False allow_pickle = False
if use_safetensors is None: if use_safetensors is None:
use_safetensors = True use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict( state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name, weight_name=weight_name,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
...@@ -5498,7 +5862,8 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5498,7 +5862,8 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
if is_non_diffusers_format: if is_non_diffusers_format:
state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict) state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
return state_dict out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
def load_lora_weights( def load_lora_weights(
...@@ -5543,7 +5908,8 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5543,7 +5908,8 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded. # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys()) is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format: if not is_correct_format:
...@@ -5553,6 +5919,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5553,6 +5919,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=self, _pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -5561,7 +5928,14 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5561,7 +5928,14 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
def load_lora_into_transformer( def load_lora_into_transformer(
cls, state_dict, transformer, adapter_name=None, _pipeline=None, low_cpu_mem_usage=False, hotswap: bool = False cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
): ):
""" """
This will load the LoRA layers specified in `state_dict` into `transformer`. This will load the LoRA layers specified in `state_dict` into `transformer`.
...@@ -5581,6 +5955,9 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5581,6 +5955,9 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
weights. weights.
hotswap (`bool`, *optional*): hotswap (`bool`, *optional*):
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
metadata (`dict`):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError( raise ValueError(
...@@ -5593,6 +5970,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5593,6 +5970,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
state_dict, state_dict,
network_alphas=None, network_alphas=None,
adapter_name=adapter_name, adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline, _pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage, low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap, hotswap=hotswap,
...@@ -5608,9 +5986,10 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5608,9 +5986,10 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = True, safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
): ):
r""" r"""
Save the LoRA parameters corresponding to the UNet and text encoder. Save the LoRA parameters corresponding to the transformer.
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
...@@ -5627,14 +6006,21 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5627,14 +6006,21 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
transformer_lora_adapter_metadata (`dict`, *optional*):
LoRA adapter metadata associated with the transformer to be serialized with the state dict.
""" """
state_dict = {} state_dict = {}
lora_adapter_metadata = {}
if not transformer_lora_layers: if not transformer_lora_layers:
raise ValueError("You must pass `transformer_lora_layers`.") raise ValueError("You must pass `transformer_lora_layers`.")
if transformer_lora_layers: state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
if transformer_lora_adapter_metadata is not None:
lora_adapter_metadata.update(
_pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
)
# Save the model # Save the model
cls.write_lora_layers( cls.write_lora_layers(
...@@ -5644,6 +6030,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin): ...@@ -5644,6 +6030,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
weight_name=weight_name, weight_name=weight_name,
save_function=save_function, save_function=save_function,
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
lora_adapter_metadata=lora_adapter_metadata,
) )
# Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect import inspect
import json
import os import os
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
...@@ -185,6 +186,7 @@ class PeftAdapterMixin: ...@@ -185,6 +186,7 @@ class PeftAdapterMixin:
Note that hotswapping adapters of the text encoder is not yet supported. There are some further Note that hotswapping adapters of the text encoder is not yet supported. There are some further
limitations to this technique, which are documented here: limitations to this technique, which are documented here:
https://huggingface.co/docs/peft/main/en/package_reference/hotswap https://huggingface.co/docs/peft/main/en/package_reference/hotswap
metadata (`dict`, *optional*):
Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
from the state dict.
""" """
from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from peft.tuners.tuners_utils import BaseTunerLayer from peft.tuners.tuners_utils import BaseTunerLayer
...@@ -202,6 +204,7 @@ class PeftAdapterMixin: ...@@ -202,6 +204,7 @@ class PeftAdapterMixin:
network_alphas = kwargs.pop("network_alphas", None) network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None) _pipeline = kwargs.pop("_pipeline", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False)
metadata = kwargs.pop("metadata", None)
allow_pickle = False allow_pickle = False
if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"): if low_cpu_mem_usage and is_peft_version("<=", "0.13.0"):
...@@ -209,12 +212,9 @@ class PeftAdapterMixin: ...@@ -209,12 +212,9 @@ class PeftAdapterMixin:
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
) )
user_agent = { user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dict = _fetch_state_dict( state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name, weight_name=weight_name,
use_safetensors=use_safetensors, use_safetensors=use_safetensors,
...@@ -227,12 +227,17 @@ class PeftAdapterMixin: ...@@ -227,12 +227,17 @@ class PeftAdapterMixin:
subfolder=subfolder, subfolder=subfolder,
user_agent=user_agent, user_agent=user_agent,
allow_pickle=allow_pickle, allow_pickle=allow_pickle,
metadata=metadata,
) )
if network_alphas is not None and prefix is None: if network_alphas is not None and prefix is None:
raise ValueError("`network_alphas` cannot be None when `prefix` is None.") raise ValueError("`network_alphas` cannot be None when `prefix` is None.")
if network_alphas and metadata:
raise ValueError("Both `network_alphas` and `metadata` cannot be specified.")
if prefix is not None: if prefix is not None:
state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")} state_dict = {k.removeprefix(f"{prefix}."): v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
if metadata is not None:
metadata = {k.removeprefix(f"{prefix}."): v for k, v in metadata.items() if k.startswith(f"{prefix}.")}
if len(state_dict) > 0: if len(state_dict) > 0:
if adapter_name in getattr(self, "peft_config", {}) and not hotswap: if adapter_name in getattr(self, "peft_config", {}) and not hotswap:
...@@ -267,7 +272,12 @@ class PeftAdapterMixin: ...@@ -267,7 +272,12 @@ class PeftAdapterMixin:
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
} }
lora_config_kwargs = get_peft_kwargs(rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict) if metadata is not None:
lora_config_kwargs = metadata
else:
lora_config_kwargs = get_peft_kwargs(
rank, network_alpha_dict=network_alphas, peft_state_dict=state_dict
)
_maybe_raise_error_for_ambiguity(lora_config_kwargs) _maybe_raise_error_for_ambiguity(lora_config_kwargs)
if "use_dora" in lora_config_kwargs: if "use_dora" in lora_config_kwargs:
...@@ -290,7 +300,11 @@ class PeftAdapterMixin: ...@@ -290,7 +300,11 @@ class PeftAdapterMixin:
if is_peft_version("<=", "0.13.2"): if is_peft_version("<=", "0.13.2"):
lora_config_kwargs.pop("lora_bias") lora_config_kwargs.pop("lora_bias")
lora_config = LoraConfig(**lora_config_kwargs) try:
lora_config = LoraConfig(**lora_config_kwargs)
except TypeError as e:
raise TypeError("`LoraConfig` class could not be instantiated.") from e
# adapter_name # adapter_name
if adapter_name is None: if adapter_name is None:
adapter_name = get_adapter_name(self) adapter_name = get_adapter_name(self)
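A minimal model-level sketch of the load path above, assuming `model` subclasses `PeftAdapterMixin` and the placeholder directory "my-lora" was written by `save_lora_adapter`, so its safetensors header carries the adapter metadata.

model.load_lora_adapter("my-lora", prefix=None, use_safetensors=True)

# When metadata is present, the LoraConfig is rebuilt from it instead of being
# inferred from rank/alpha patterns in the state dict.
print(model.peft_config)  # e.g. {"default_0": LoraConfig(...)}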
...@@ -445,17 +459,13 @@ class PeftAdapterMixin: ...@@ -445,17 +459,13 @@ class PeftAdapterMixin:
underlying model has multiple adapters loaded. underlying model has multiple adapters loaded.
upcast_before_saving (`bool`, defaults to `False`): upcast_before_saving (`bool`, defaults to `False`):
Whether to cast the underlying model to `torch.float32` before serialization. Whether to cast the underlying model to `torch.float32` before serialization.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with. weight_name: (`str`, *optional*, defaults to `None`): Name of the file to serialize the state dict with.
""" """
from peft.utils import get_peft_model_state_dict from peft.utils import get_peft_model_state_dict
from .lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE from .lora_base import LORA_ADAPTER_METADATA_KEY, LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE
if adapter_name is None: if adapter_name is None:
adapter_name = get_adapter_name(self) adapter_name = get_adapter_name(self)
...@@ -463,6 +473,8 @@ class PeftAdapterMixin: ...@@ -463,6 +473,8 @@ class PeftAdapterMixin:
if adapter_name not in getattr(self, "peft_config", {}): if adapter_name not in getattr(self, "peft_config", {}):
raise ValueError(f"Adapter name {adapter_name} not found in the model.") raise ValueError(f"Adapter name {adapter_name} not found in the model.")
lora_adapter_metadata = self.peft_config[adapter_name].to_dict()
lora_layers_to_save = get_peft_model_state_dict( lora_layers_to_save = get_peft_model_state_dict(
self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name self.to(dtype=torch.float32 if upcast_before_saving else None), adapter_name=adapter_name
) )
...@@ -472,7 +484,15 @@ class PeftAdapterMixin: ...@@ -472,7 +484,15 @@ class PeftAdapterMixin:
if safe_serialization: if safe_serialization:
def save_function(weights, filename): def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) # Inject framework format.
metadata = {"format": "pt"}
if lora_adapter_metadata is not None:
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
return safetensors.torch.save_file(weights, filename, metadata=metadata)
else: else:
save_function = torch.save save_function = torch.save
...@@ -485,7 +505,6 @@ class PeftAdapterMixin: ...@@ -485,7 +505,6 @@ class PeftAdapterMixin:
else: else:
weight_name = LORA_WEIGHT_NAME weight_name = LORA_WEIGHT_NAME
# TODO: we could consider saving the `peft_config` as well.
save_path = Path(save_directory, weight_name).as_posix() save_path = Path(save_directory, weight_name).as_posix()
save_function(lora_layers_to_save, save_path) save_function(lora_layers_to_save, save_path)
logger.info(f"Model weights saved in {save_path}") logger.info(f"Model weights saved in {save_path}")
......
...@@ -155,10 +155,7 @@ class UNet2DConditionLoadersMixin: ...@@ -155,10 +155,7 @@ class UNet2DConditionLoadersMixin:
use_safetensors = True use_safetensors = True
allow_pickle = True allow_pickle = True
user_agent = { user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
model_file = None model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict): if not isinstance(pretrained_model_name_or_path_or_dict, dict):
......
...@@ -16,6 +16,7 @@ State dict utilities: utility methods for converting state dicts easily ...@@ -16,6 +16,7 @@ State dict utilities: utility methods for converting state dicts easily
""" """
import enum import enum
import json
from .import_utils import is_torch_available from .import_utils import is_torch_available
from .logging import get_logger from .logging import get_logger
...@@ -347,3 +348,16 @@ def state_dict_all_zero(state_dict, filter_str=None): ...@@ -347,3 +348,16 @@ def state_dict_all_zero(state_dict, filter_str=None):
state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)} state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}
return all(torch.all(param == 0).item() for param in state_dict.values()) return all(torch.all(param == 0).item() for param in state_dict.values())
def _load_sft_state_dict_metadata(model_file: str):
import safetensors.torch
from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
return json.loads(raw) if raw else None
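For reference, a hand-rolled sketch of what this helper reads, using `safetensors` directly; the file path is a placeholder.

import json

from safetensors import safe_open

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY

with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    header = f.metadata() or {}

header.pop("format", None)  # drop the framework tag, keep only the LoRA metadata
raw = header.get(LORA_ADAPTER_METADATA_KEY)
lora_config_kwargs = json.loads(raw) if raw else None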
...@@ -133,6 +133,29 @@ def numpy_cosine_similarity_distance(a, b): ...@@ -133,6 +133,29 @@ def numpy_cosine_similarity_distance(a, b):
return distance return distance
def check_if_dicts_are_equal(dict1, dict2):
dict1, dict2 = dict1.copy(), dict2.copy()
for key, value in dict1.items():
if isinstance(value, set):
dict1[key] = sorted(value)
for key, value in dict2.items():
if isinstance(value, set):
dict2[key] = sorted(value)
for key in dict1:
if key not in dict2:
return False
if dict1[key] != dict2[key]:
return False
for key in dict2:
if key not in dict1:
return False
return True
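A quick illustration of the set normalization this helper performs:

config_a = {"target_modules": {"to_q", "to_k"}, "r": 4}
config_b = {"target_modules": ["to_k", "to_q"], "r": 4}

# Sets are sorted into lists before comparison, so set-vs-list differences
# (as produced by JSON round-trips) do not cause false mismatches.
assert check_if_dicts_are_equal(config_a, config_b)
assert not check_if_dicts_are_equal(config_a, {"r": 8})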
def print_tensor_test( def print_tensor_test(
tensor, tensor,
limit_to_slices=None, limit_to_slices=None,
......
...@@ -24,11 +24,7 @@ from diffusers import ( ...@@ -24,11 +24,7 @@ from diffusers import (
WanPipeline, WanPipeline,
WanTransformer3DModel, WanTransformer3DModel,
) )
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import floats_tensor, require_peft_backend, skip_mps
floats_tensor,
require_peft_backend,
skip_mps,
)
sys.path.append(".") sys.path.append(".")
......
...@@ -22,6 +22,7 @@ from itertools import product ...@@ -22,6 +22,7 @@ from itertools import product
import numpy as np import numpy as np
import pytest import pytest
import torch import torch
from parameterized import parameterized
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
...@@ -33,6 +34,7 @@ from diffusers.utils import logging ...@@ -33,6 +34,7 @@ from diffusers.utils import logging
from diffusers.utils.import_utils import is_peft_available from diffusers.utils.import_utils import is_peft_available
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
CaptureLogger, CaptureLogger,
check_if_dicts_are_equal,
floats_tensor, floats_tensor,
is_torch_version, is_torch_version,
require_peft_backend, require_peft_backend,
...@@ -71,6 +73,13 @@ def check_if_lora_correctly_set(model) -> bool: ...@@ -71,6 +73,13 @@ def check_if_lora_correctly_set(model) -> bool:
return False return False
def check_module_lora_metadata(parsed_metadata: dict, lora_metadatas: dict, module_key: str):
extracted = {
k.removeprefix(f"{module_key}."): v for k, v in parsed_metadata.items() if k.startswith(f"{module_key}.")
}
assert check_if_dicts_are_equal(extracted, lora_metadatas[f"{module_key}_lora_adapter_metadata"])
def initialize_dummy_state_dict(state_dict): def initialize_dummy_state_dict(state_dict):
if not all(v.device.type == "meta" for _, v in state_dict.items()): if not all(v.device.type == "meta" for _, v in state_dict.items()):
raise ValueError("`state_dict` has non-meta values.") raise ValueError("`state_dict` has non-meta values.")
...@@ -118,7 +127,7 @@ class PeftLoraLoaderMixinTests: ...@@ -118,7 +127,7 @@ class PeftLoraLoaderMixinTests:
text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"] denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
def get_dummy_components(self, scheduler_cls=None, use_dora=False): def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
if self.unet_kwargs and self.transformer_kwargs: if self.unet_kwargs and self.transformer_kwargs:
raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.") raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
if self.has_two_text_encoders and self.has_three_text_encoders: if self.has_two_text_encoders and self.has_three_text_encoders:
...@@ -126,6 +135,7 @@ class PeftLoraLoaderMixinTests: ...@@ -126,6 +135,7 @@ class PeftLoraLoaderMixinTests:
scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls scheduler_cls = self.scheduler_cls if scheduler_cls is None else scheduler_cls
rank = 4 rank = 4
lora_alpha = rank if lora_alpha is None else lora_alpha
torch.manual_seed(0) torch.manual_seed(0)
if self.unet_kwargs is not None: if self.unet_kwargs is not None:
...@@ -161,7 +171,7 @@ class PeftLoraLoaderMixinTests: ...@@ -161,7 +171,7 @@ class PeftLoraLoaderMixinTests:
text_lora_config = LoraConfig( text_lora_config = LoraConfig(
r=rank, r=rank,
lora_alpha=rank, lora_alpha=lora_alpha,
target_modules=self.text_encoder_target_modules, target_modules=self.text_encoder_target_modules,
init_lora_weights=False, init_lora_weights=False,
use_dora=use_dora, use_dora=use_dora,
...@@ -169,7 +179,7 @@ class PeftLoraLoaderMixinTests: ...@@ -169,7 +179,7 @@ class PeftLoraLoaderMixinTests:
denoiser_lora_config = LoraConfig( denoiser_lora_config = LoraConfig(
r=rank, r=rank,
lora_alpha=rank, lora_alpha=lora_alpha,
target_modules=self.denoiser_target_modules, target_modules=self.denoiser_target_modules,
init_lora_weights=False, init_lora_weights=False,
use_dora=use_dora, use_dora=use_dora,
...@@ -246,6 +256,13 @@ class PeftLoraLoaderMixinTests: ...@@ -246,6 +256,13 @@ class PeftLoraLoaderMixinTests:
state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module) state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
return state_dicts return state_dicts
def _get_lora_adapter_metadata(self, modules_to_save):
metadatas = {}
for module_name, module in modules_to_save.items():
if module is not None:
metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
return metadatas
def _get_modules_to_save(self, pipe, has_denoiser=False): def _get_modules_to_save(self, pipe, has_denoiser=False):
modules_to_save = {} modules_to_save = {}
lora_loadable_modules = self.pipeline_class._lora_loadable_modules lora_loadable_modules = self.pipeline_class._lora_loadable_modules
...@@ -2214,6 +2231,86 @@ class PeftLoraLoaderMixinTests: ...@@ -2214,6 +2231,86 @@ class PeftLoraLoaderMixinTests:
_, _, inputs = self.get_dummy_inputs(with_generator=False) _, _, inputs = self.get_dummy_inputs(with_generator=False)
pipe(**inputs, generator=torch.manual_seed(0))[0] pipe(**inputs, generator=torch.manual_seed(0))[0]
@parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_is_loaded_correctly(self, lora_alpha):
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, lora_alpha=lora_alpha
)
pipe = self.pipeline_class(**components)
pipe, _ = self.check_if_adapters_added_correctly(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
pipe.unload_lora_weights()
out = pipe.lora_state_dict(tmpdir, return_lora_metadata=True)
if len(out) == 3:
_, _, parsed_metadata = out
elif len(out) == 2:
_, parsed_metadata = out
denoiser_key = (
f"{self.pipeline_class.transformer_name}"
if self.transformer_kwargs is not None
else f"{self.pipeline_class.unet_name}"
)
self.assertTrue(any(k.startswith(f"{denoiser_key}.") for k in parsed_metadata))
check_module_lora_metadata(
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=denoiser_key
)
if "text_encoder" in self.pipeline_class._lora_loadable_modules:
text_encoder_key = self.pipeline_class.text_encoder_name
self.assertTrue(any(k.startswith(f"{text_encoder_key}.") for k in parsed_metadata))
check_module_lora_metadata(
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_key
)
if "text_encoder_2" in self.pipeline_class._lora_loadable_modules:
text_encoder_2_key = "text_encoder_2"
self.assertTrue(any(k.startswith(f"{text_encoder_2_key}.") for k in parsed_metadata))
check_module_lora_metadata(
parsed_metadata=parsed_metadata, lora_metadatas=lora_metadatas, module_key=text_encoder_2_key
)
@parameterized.expand([4, 8, 16])
def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
scheduler_cls = self.scheduler_classes[0]
components, text_lora_config, denoiser_lora_config = self.get_dummy_components(
scheduler_cls, lora_alpha=lora_alpha
)
pipe = self.pipeline_class(**components).to(torch_device)
_, _, inputs = self.get_dummy_inputs(with_generator=False)
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(output_no_lora.shape == self.output_shape)
pipe, _ = self.check_if_adapters_added_correctly(
pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
)
output_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
with tempfile.TemporaryDirectory() as tmpdir:
modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
pipe.unload_lora_weights()
pipe.load_lora_weights(tmpdir)
output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
self.assertTrue(
np.allclose(output_lora, output_lora_pretrained, atol=1e-3, rtol=1e-3), "Lora outputs should match."
)
def test_inference_load_delete_load_adapters(self): def test_inference_load_delete_load_adapters(self):
"Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works." "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
for scheduler_cls in self.scheduler_classes: for scheduler_cls in self.scheduler_classes:
......
...@@ -30,6 +30,7 @@ from typing import Dict, List, Optional, Tuple, Union ...@@ -30,6 +30,7 @@ from typing import Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import requests_mock import requests_mock
import safetensors.torch
import torch import torch
import torch.nn as nn import torch.nn as nn
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
...@@ -62,6 +63,7 @@ from diffusers.utils.testing_utils import ( ...@@ -62,6 +63,7 @@ from diffusers.utils.testing_utils import (
backend_max_memory_allocated, backend_max_memory_allocated,
backend_reset_peak_memory_stats, backend_reset_peak_memory_stats,
backend_synchronize, backend_synchronize,
check_if_dicts_are_equal,
get_python_version, get_python_version,
is_torch_compile, is_torch_compile,
numpy_cosine_similarity_distance, numpy_cosine_similarity_distance,
...@@ -1057,11 +1059,10 @@ class ModelTesterMixin: ...@@ -1057,11 +1059,10 @@ class ModelTesterMixin:
" from `_deprecated_kwargs = [<deprecated_argument>]`" " from `_deprecated_kwargs = [<deprecated_argument>]`"
) )
@parameterized.expand([True, False]) @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad() @torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT") @unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_save_load_adapter(self, use_dora=False): def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False):
import safetensors
from peft import LoraConfig from peft import LoraConfig
from peft.utils import get_peft_model_state_dict from peft.utils import get_peft_model_state_dict
...@@ -1077,8 +1078,8 @@ class ModelTesterMixin: ...@@ -1077,8 +1078,8 @@ class ModelTesterMixin:
output_no_lora = model(**inputs_dict, return_dict=False)[0] output_no_lora = model(**inputs_dict, return_dict=False)[0]
denoiser_lora_config = LoraConfig( denoiser_lora_config = LoraConfig(
r=4, r=rank,
lora_alpha=4, lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"], target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False, init_lora_weights=False,
use_dora=use_dora, use_dora=use_dora,
...@@ -1145,6 +1146,90 @@ class ModelTesterMixin: ...@@ -1145,6 +1146,90 @@ class ModelTesterMixin:
self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception)) self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))
@parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
from peft import LoraConfig
from diffusers.loaders.peft import PeftAdapterMixin
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
return
denoiser_lora_config = LoraConfig(
r=rank,
lora_alpha=lora_alpha,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=use_dora,
)
model.add_adapter(denoiser_lora_config)
metadata = model.peft_config["default"].to_dict()
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(model_file))
model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
parsed_metadata = model.peft_config["default_0"].to_dict()
self.assertTrue(check_if_dicts_are_equal(metadata, parsed_metadata))
@torch.no_grad()
@unittest.skipIf(not is_peft_available(), "Only with PEFT")
def test_lora_adapter_wrong_metadata_raises_error(self):
from peft import LoraConfig
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
from diffusers.loaders.peft import PeftAdapterMixin
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)
if not issubclass(model.__class__, PeftAdapterMixin):
return
denoiser_lora_config = LoraConfig(
r=4,
lora_alpha=4,
target_modules=["to_q", "to_k", "to_v", "to_out.0"],
init_lora_weights=False,
use_dora=False,
)
model.add_adapter(denoiser_lora_config)
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with tempfile.TemporaryDirectory() as tmpdir:
model.save_lora_adapter(tmpdir)
model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(model_file))
# Perturb the metadata in the state dict.
loaded_state_dict = safetensors.torch.load_file(model_file)
metadata = {"format": "pt"}
lora_adapter_metadata = denoiser_lora_config.to_dict()
lora_adapter_metadata.update({"foo": 1, "bar": 2})
for key, value in lora_adapter_metadata.items():
if isinstance(value, set):
lora_adapter_metadata[key] = list(value)
metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
model.unload_lora()
self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
with self.assertRaises(TypeError) as err_context:
model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
@require_torch_accelerator @require_torch_accelerator
def test_cpu_offload(self): def test_cpu_offload(self):
config, inputs_dict = self.prepare_init_args_and_inputs_for_common() config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
......