Unverified Commit ae368e42 authored by Matthieu Bizien's avatar Matthieu Bizien Committed by GitHub
Browse files

[Proposal] Support saving to safetensors (#1494)

* Add parameter safe_serialization to DiffusionPipeline.save_pretrained

* Add option safe_serialization on ModelMixin.save_pretrained

* Add test test_save_safe_serialization

* Black

* Re-trigger the CI

* Fix doc-builder

* Validate files are saved as safetensor in test_save_safe_serialization
parent cf4664e8
...@@ -191,7 +191,8 @@ class ModelMixin(torch.nn.Module): ...@@ -191,7 +191,8 @@ class ModelMixin(torch.nn.Module):
self, self,
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
is_main_process: bool = True, is_main_process: bool = True,
save_function: Callable = torch.save, save_function: Callable = None,
safe_serialization: bool = False,
): ):
""" """
Save a model and its configuration file to a directory, so that it can be re-loaded using the Save a model and its configuration file to a directory, so that it can be re-loaded using the
...@@ -206,12 +207,21 @@ class ModelMixin(torch.nn.Module): ...@@ -206,12 +207,21 @@ class ModelMixin(torch.nn.Module):
the main process to avoid race conditions. the main process to avoid race conditions.
save_function (`Callable`): save_function (`Callable`):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method. need to replace `torch.save` by another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
""" """
if safe_serialization and not is_safetensors_available():
raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
if os.path.isfile(save_directory): if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file") logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return return
if save_function is None:
save_function = safetensors.torch.save_file if safe_serialization else torch.save
os.makedirs(save_directory, exist_ok=True) os.makedirs(save_directory, exist_ok=True)
model_to_save = self model_to_save = self
...@@ -224,18 +234,21 @@ class ModelMixin(torch.nn.Module): ...@@ -224,18 +234,21 @@ class ModelMixin(torch.nn.Module):
# Save the model # Save the model
state_dict = model_to_save.state_dict() state_dict = model_to_save.state_dict()
weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
# Clean the folder from a previous save # Clean the folder from a previous save
for filename in os.listdir(save_directory): for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename) full_filename = os.path.join(save_directory, filename)
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions. # in distributed settings to avoid race conditions.
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process: weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
os.remove(full_filename) os.remove(full_filename)
# Save the model # Save the model
save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME)) save_function(state_dict, os.path.join(save_directory, weights_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}") logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
@classmethod @classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
......
...@@ -188,7 +188,11 @@ class DiffusionPipeline(ConfigMixin): ...@@ -188,7 +188,11 @@ class DiffusionPipeline(ConfigMixin):
# set models # set models
setattr(self, name, module) setattr(self, name, module)
def save_pretrained(self, save_directory: Union[str, os.PathLike]): def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
safe_serialization: bool = False,
):
""" """
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
...@@ -197,6 +201,8 @@ class DiffusionPipeline(ConfigMixin): ...@@ -197,6 +201,8 @@ class DiffusionPipeline(ConfigMixin):
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist. Directory to which to save. Will be created if it doesn't exist.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
""" """
self.save_config(save_directory) self.save_config(save_directory)
...@@ -234,6 +240,15 @@ class DiffusionPipeline(ConfigMixin): ...@@ -234,6 +240,15 @@ class DiffusionPipeline(ConfigMixin):
break break
save_method = getattr(sub_model, save_method_name) save_method = getattr(sub_model, save_method_name)
# Call the save method with the argument safe_serialization only if it's supported
save_method_signature = inspect.signature(save_method)
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
if save_method_accept_safe:
save_method(
os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization
)
else:
save_method(os.path.join(save_directory, pipeline_component_name)) save_method(os.path.join(save_directory, pipeline_component_name))
def to(self, torch_device: Optional[Union[str, torch.device]] = None): def to(self, torch_device: Optional[Union[str, torch.device]] = None):
......
...@@ -25,6 +25,8 @@ import numpy as np ...@@ -25,6 +25,8 @@ import numpy as np
import torch import torch
import PIL import PIL
import safetensors.torch
import transformers
from diffusers import ( from diffusers import (
AutoencoderKL, AutoencoderKL,
DDIMPipeline, DDIMPipeline,
...@@ -537,6 +539,34 @@ class PipelineFastTests(unittest.TestCase): ...@@ -537,6 +539,34 @@ class PipelineFastTests(unittest.TestCase):
assert dict(ddim_config) == dict(ddim_config_2) assert dict(ddim_config) == dict(ddim_config_2)
def test_save_safe_serialization(self):
    """Saving a pipeline with `safe_serialization=True` must write `.safetensors`
    weight files for each sub-model, and the result must load back into a pipeline."""
    pipeline = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")

    def _version_tuple(version_string):
        # Parse the leading numeric components of a version ("4.25.1.dev0" -> (4, 25, 1)).
        # Raw string comparison is a bug: "4.9.0" >= "4.25.1" is True lexicographically
        # because '9' > '2', so the old check fired on versions that predate the feature.
        parts = []
        for piece in version_string.split(".")[:3]:
            digits = ""
            for ch in piece:
                if not ch.isdigit():
                    break
                digits += ch
            parts.append(int(digits or 0))
        return tuple(parts)

    with tempfile.TemporaryDirectory() as tmpdirname:
        pipeline.save_pretrained(tmpdirname, safe_serialization=True)

        # Validate that the VAE safetensors file exists and is of the correct format
        vae_path = os.path.join(tmpdirname, "vae", "diffusion_pytorch_model.safetensors")
        assert os.path.exists(vae_path), f"Could not find {vae_path}"
        _ = safetensors.torch.load_file(vae_path)

        # Validate that the UNet safetensors file exists and is of the correct format
        unet_path = os.path.join(tmpdirname, "unet", "diffusion_pytorch_model.safetensors")
        assert os.path.exists(unet_path), f"Could not find {unet_path}"
        _ = safetensors.torch.load_file(unet_path)

        # The text encoder is saved by transformers; per the original version gate,
        # safetensors output is only expected from transformers >= 4.25.1 — older
        # versions still write a .bin, so skip the assertion there.
        text_encoder_path = os.path.join(tmpdirname, "text_encoder", "model.safetensors")
        if _version_tuple(transformers.__version__) >= (4, 25, 1):
            assert os.path.exists(text_encoder_path), f"Could not find {text_encoder_path}"
            _ = safetensors.torch.load_file(text_encoder_path)

        # Round-trip: the safetensors checkpoints must load back into a working pipeline.
        pipeline = StableDiffusionPipeline.from_pretrained(tmpdirname)
        assert pipeline.unet is not None
        assert pipeline.vae is not None
        assert pipeline.text_encoder is not None
        assert pipeline.scheduler is not None
        assert pipeline.feature_extractor is not None
def test_optional_components(self): def test_optional_components(self):
unet = self.dummy_cond_unet() unet = self.dummy_cond_unet()
pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment