Unverified Commit d185c0df authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Lora] correct lora saving & loading (#2655)

* [Lora] correct lora saving & loading

* fix final

* Apply suggestions from code review
parent 7c1b3477
...@@ -19,7 +19,7 @@ import torch ...@@ -19,7 +19,7 @@ import torch
from .models.cross_attention import LoRACrossAttnProcessor from .models.cross_attention import LoRACrossAttnProcessor
from .models.modeling_utils import _get_model_file from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging
if is_safetensors_available(): if is_safetensors_available():
...@@ -150,13 +150,14 @@ class UNet2DConditionLoadersMixin: ...@@ -150,13 +150,14 @@ class UNet2DConditionLoadersMixin:
model_file = None model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict): if not isinstance(pretrained_model_name_or_path_or_dict, dict):
if (is_safetensors_available() and weight_name is None) or weight_name.endswith(".safetensors"): # Let's first try to load .safetensors weights
if weight_name is None: if (is_safetensors_available() and weight_name is None) or (
weight_name = LORA_WEIGHT_NAME_SAFE weight_name is not None and weight_name.endswith(".safetensors")
):
try: try:
model_file = _get_model_file( model_file = _get_model_file(
pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict,
weights_name=weight_name, weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
resume_download=resume_download, resume_download=resume_download,
...@@ -169,14 +170,13 @@ class UNet2DConditionLoadersMixin: ...@@ -169,14 +170,13 @@ class UNet2DConditionLoadersMixin:
) )
state_dict = safetensors.torch.load_file(model_file, device="cpu") state_dict = safetensors.torch.load_file(model_file, device="cpu")
except EnvironmentError: except EnvironmentError:
if weight_name == LORA_WEIGHT_NAME_SAFE: # try loading non-safetensors weights
weight_name = None pass
if model_file is None: if model_file is None:
if weight_name is None:
weight_name = LORA_WEIGHT_NAME
model_file = _get_model_file( model_file = _get_model_file(
pretrained_model_name_or_path_or_dict, pretrained_model_name_or_path_or_dict,
weights_name=weight_name, weights_name=weight_name or LORA_WEIGHT_NAME,
cache_dir=cache_dir, cache_dir=cache_dir,
force_download=force_download, force_download=force_download,
resume_download=resume_download, resume_download=resume_download,
...@@ -225,9 +225,10 @@ class UNet2DConditionLoadersMixin: ...@@ -225,9 +225,10 @@ class UNet2DConditionLoadersMixin:
self, self,
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
is_main_process: bool = True, is_main_process: bool = True,
weights_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = False, safe_serialization: bool = False,
**kwargs,
): ):
r""" r"""
Save an attention processor to a directory, so that it can be re-loaded using the Save an attention processor to a directory, so that it can be re-loaded using the
...@@ -245,6 +246,12 @@ class UNet2DConditionLoadersMixin: ...@@ -245,6 +246,12 @@ class UNet2DConditionLoadersMixin:
need to replace `torch.save` by another method. Can be configured with the environment variable need to replace `torch.save` by another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
""" """
weight_name = weight_name or deprecate(
"weights_name",
"0.18.0",
"`weights_name` is deprecated, please use `weight_name` instead.",
take_from=kwargs,
)
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return return
...@@ -265,22 +272,13 @@ class UNet2DConditionLoadersMixin: ...@@ -265,22 +272,13 @@ class UNet2DConditionLoadersMixin:
# Save the model # Save the model
state_dict = model_to_save.state_dict() state_dict = model_to_save.state_dict()
# Clean the folder from a previous save if weight_name is None:
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
weights_no_suffix = weights_name.replace(".bin", "")
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
os.remove(full_filename)
if weights_name is None:
if safe_serialization: if safe_serialization:
weights_name = LORA_WEIGHT_NAME_SAFE weight_name = LORA_WEIGHT_NAME_SAFE
else: else:
weights_name = LORA_WEIGHT_NAME weight_name = LORA_WEIGHT_NAME
# Save the model # Save the model
save_function(state_dict, os.path.join(save_directory, weights_name)) save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}") logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment