"docs/vscode:/vscode.git/clone" did not exist on "e2db2eddbb1699a59fbb5ccbec912979048ef3bf"
Unverified commit 029fb416, authored by Patrick von Platen, committed by GitHub

[Safetensors] Make safetensors the default way of saving weights (#4235)



* make safetensors default

* set default save method as safetensors

* update tests

* update to support saving safetensors

* update test to account for safetensors default

* update example tests to use safetensors

* update example to support safetensors

* update unet tests for safetensors

* fix failing loader tests

* fix qc issues

* fix pipeline tests

* fix example test

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 852dc76d
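Every file touched by this commit applies the same pattern: a `safe_serialization` switch that routes a state dict either through `safetensors.torch.save_file` or through the legacy `torch.save` pickle path. A minimal, self-contained sketch of that round trip (tensor names and file paths are illustrative, not from the diff):

```python
import torch
import safetensors.torch

state_dict = {"token_embedding": torch.randn(1, 768)}

# New default: pickle-free serialization via safetensors.
safetensors.torch.save_file(state_dict, "embeds.safetensors", metadata={"format": "pt"})
loaded = safetensors.torch.load_file("embeds.safetensors")

# Legacy path, kept behind an opt-out (safe_serialization=False / --no_safe_serialization).
torch.save(state_dict, "embeds.bin")
legacy = torch.load("embeds.bin")

assert torch.equal(loaded["token_embedding"], legacy["token_embedding"])
```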
@@ -26,6 +26,7 @@ import warnings
 from pathlib import Path

 import numpy as np
+import safetensors
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -296,14 +297,19 @@ class CustomDiffusionDataset(Dataset):
         return example


-def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir):
+def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True):
     """Saves the new token embeddings from the text encoder."""
     logger.info("Saving embeddings")
     learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
     for x, y in zip(modifier_token_id, args.modifier_token):
         learned_embeds_dict = {}
         learned_embeds_dict[y] = learned_embeds[x]
-        torch.save(learned_embeds_dict, f"{output_dir}/{y}.bin")
+        filename = f"{output_dir}/{y}.safetensors" if safe_serialization else f"{output_dir}/{y}.bin"
+        if safe_serialization:
+            safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
+        else:
+            torch.save(learned_embeds_dict, filename)


 def parse_args(input_args=None):
@@ -605,6 +611,11 @@ def parse_args(input_args=None):
         action="store_true",
         help="Don't apply augmentation when this flag is enabled.",
     )
+    parser.add_argument(
+        "--no_safe_serialization",
+        action="store_true",
+        help="If specified, save the checkpoint in the original PyTorch format instead of `safetensors`.",
+    )

     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -1244,8 +1255,15 @@ def main(args):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = unet.to(torch.float32)
-        unet.save_attn_procs(args.output_dir)
-        save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.output_dir)
+        unet.save_attn_procs(args.output_dir, safe_serialization=not args.no_safe_serialization)
+        save_new_embed(
+            text_encoder,
+            modifier_token_id,
+            accelerator,
+            args,
+            args.output_dir,
+            safe_serialization=not args.no_safe_serialization,
+        )

         # Final inference
         # Load previous pipeline
@@ -1256,9 +1274,15 @@ def main(args):
         pipeline = pipeline.to(accelerator.device)

         # load attention processors
-        pipeline.unet.load_attn_procs(args.output_dir, weight_name="pytorch_custom_diffusion_weights.bin")
+        weight_name = (
+            "pytorch_custom_diffusion_weights.safetensors"
+            if not args.no_safe_serialization
+            else "pytorch_custom_diffusion_weights.bin"
+        )
+        pipeline.unet.load_attn_procs(args.output_dir, weight_name=weight_name)
         for token in args.modifier_token:
-            pipeline.load_textual_inversion(args.output_dir, weight_name=f"{token}.bin")
+            token_weight_name = f"{token}.safetensors" if not args.no_safe_serialization else f"{token}.bin"
+            pipeline.load_textual_inversion(args.output_dir, weight_name=token_weight_name)

         # run inference
         if args.validation_prompt and args.num_validation_images > 0:
......
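For the custom diffusion script above, the opt-out flag has to be applied consistently on both the save and the load side, since the weight file's extension follows the chosen format. A sketch of the inference side, assuming a finished training run in `output_dir` with the default (safetensors) serialization; the model id, paths, and prompt are illustrative:

```python
import torch
from diffusers import DiffusionPipeline

output_dir = "path/to/custom-diffusion-output"  # placeholder for args.output_dir

pipeline = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Attention processors and token embeddings were written as .safetensors by default.
pipeline.unet.load_attn_procs(output_dir, weight_name="pytorch_custom_diffusion_weights.safetensors")
pipeline.load_textual_inversion(output_dir, weight_name="<new1>.safetensors")

image = pipeline("<new1> cat swimming in a pool", num_inference_steps=25).images[0]
```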
@@ -1374,7 +1374,7 @@ def main(args):
         pipeline = pipeline.to(accelerator.device)

         # load attention processors
-        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin")
+        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")

         # run inference
         images = []
......
@@ -23,7 +23,7 @@ import tempfile
 import unittest
 from typing import List

 import torch
+import safetensors
 from accelerate.utils import write_basic_config

 from diffusers import DiffusionPipeline, UNet2DConditionModel
@@ -93,7 +93,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args, return_stdout=True)

             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
             self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

     def test_textual_inversion(self):
@@ -144,7 +144,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args)

             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
             self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

     def test_dreambooth_if(self):
@@ -170,7 +170,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args)

             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
             self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

     def test_dreambooth_checkpointing(self):
@@ -272,10 +272,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args)

             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

             # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
             is_lora = all("lora" in k for k in lora_state_dict.keys())
             self.assertTrue(is_lora)
@@ -305,10 +305,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args)

             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

             # check `text_encoder` is present at all.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
             keys = lora_state_dict.keys()
             is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
             self.assertTrue(is_text_encoder_present)
@@ -341,10 +341,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args)

             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

             # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
             is_lora = all("lora" in k for k in lora_state_dict.keys())
             self.assertTrue(is_lora)
@@ -373,10 +373,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args)

             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

             # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
             is_lora = all("lora" in k for k in lora_state_dict.keys())
             self.assertTrue(is_lora)
@@ -406,10 +406,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args)

             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

             # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
             is_lora = all("lora" in k for k in lora_state_dict.keys())
             self.assertTrue(is_lora)
@@ -437,6 +437,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
                 --lr_scheduler constant
                 --lr_warmup_steps 0
                 --modifier_token <new1>
+                --no_safe_serialization
                 --output_dir {tmpdir}
                 """.split()
@@ -466,7 +467,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args)

             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
             self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

     def test_text_to_image_checkpointing(self):
@@ -778,7 +779,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args)

             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
             self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

     def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
@@ -1373,7 +1374,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args)

-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))

     def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -1390,6 +1391,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
                 --max_train_steps=6
                 --checkpoints_total_limit=2
                 --checkpointing_steps=2
+                --no_safe_serialization
                 """.split()

             run_command(self._launch_args + test_args)
@@ -1413,6 +1415,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
                 --dataloader_num_workers=0
                 --max_train_steps=9
                 --checkpointing_steps=2
+                --no_safe_serialization
                 """.split()

             run_command(self._launch_args + test_args)
@@ -1436,6 +1439,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
                 --checkpointing_steps=2
                 --resume_from_checkpoint=checkpoint-8
                 --checkpoints_total_limit=3
+                --no_safe_serialization
                 """.split()

             run_command(self._launch_args + resume_run_args)
@@ -1464,10 +1468,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args)

             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

             # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
             is_lora = all("lora" in k for k in lora_state_dict.keys())
             self.assertTrue(is_lora)
@@ -1491,10 +1495,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             run_command(self._launch_args + test_args)

             # save_pretrained smoke test
-            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

             # make sure the state_dict has the correct naming in the parameters.
-            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
             is_lora = all("lora" in k for k in lora_state_dict.keys())
             self.assertTrue(is_lora)
......
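The updated tests all verify checkpoints the same way: assert that the `.safetensors` file exists, then read it back with `safetensors.torch.load_file`, which never unpickles anything. A condensed sketch of that assertion pattern (the helper name is ours, not from the test suite):

```python
import os
import safetensors.torch

def assert_lora_checkpoint(tmpdir: str) -> None:
    # The smoke test now expects the safetensors default...
    path = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
    assert os.path.isfile(path)
    # ...and key naming is checked on the safely loaded state dict.
    lora_state_dict = safetensors.torch.load_file(path)
    assert all("lora" in k for k in lora_state_dict)
```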
@@ -24,6 +24,7 @@ from pathlib import Path
 import numpy as np
 import PIL
+import safetensors
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -157,7 +158,7 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
     return images


-def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path):
+def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True):
     logger.info("Saving embeddings")
     learned_embeds = (
         accelerator.unwrap_model(text_encoder)
@@ -165,7 +166,11 @@ def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_p
         .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
     )
     learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
-    torch.save(learned_embeds_dict, save_path)
+
+    if safe_serialization:
+        safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"})
+    else:
+        torch.save(learned_embeds_dict, save_path)


 def parse_args():
@@ -409,6 +414,11 @@ def parse_args():
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
     )
+    parser.add_argument(
+        "--no_safe_serialization",
+        action="store_true",
+        help="If specified, save the checkpoint in the original PyTorch format instead of `safetensors`.",
+    )

     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -878,7 +888,14 @@ def main():
                 global_step += 1
                 if global_step % args.save_steps == 0:
                     save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
-                    save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
+                    save_progress(
+                        text_encoder,
+                        placeholder_token_ids,
+                        accelerator,
+                        args,
+                        save_path,
+                        safe_serialization=not args.no_safe_serialization,
+                    )

                 if accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
@@ -936,7 +953,14 @@ def main():
             pipeline.save_pretrained(args.output_dir)

         # Save the newly trained embeddings
         save_path = os.path.join(args.output_dir, "learned_embeds.bin")
-        save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path)
+        save_progress(
+            text_encoder,
+            placeholder_token_ids,
+            accelerator,
+            args,
+            save_path,
+            safe_serialization=not args.no_safe_serialization,
+        )

         if args.push_to_hub:
             save_model_card(
......
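In both training examples the CLI flag is deliberately negative, so the safe default requires no action from the user: `safe_serialization = not args.no_safe_serialization`. A sketch of how that wiring behaves (the parser setup mirrors the diff; the assertions are ours):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--no_safe_serialization", action="store_true")

# Default invocation: safetensors stays on.
args = parser.parse_args([])
assert (not args.no_safe_serialization) is True

# Explicit opt-out flips every save/load site back to torch.save and .bin files.
args = parser.parse_args(["--no_safe_serialization"])
assert (not args.no_safe_serialization) is False
```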
@@ -497,7 +497,8 @@ class UNet2DConditionLoadersMixin:
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
+        **kwargs,
     ):
         r"""
         Save an attention processor to a directory so that it can be reloaded using the
@@ -514,7 +515,8 @@ class UNet2DConditionLoadersMixin:
                 The function to use to save the state dictionary. Useful during distributed training when you need to
                 replace `torch.save` with another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         from .models.attention_processor import (
             CustomDiffusionAttnProcessor,
@@ -1414,7 +1416,7 @@ class LoraLoaderMixin:
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
     ):
         r"""
         Save the LoRA parameters corresponding to the UNet and text encoder.
@@ -1435,6 +1437,8 @@ class LoraLoaderMixin:
                 The function to use to save the state dictionary. Useful during distributed training when you need to
                 replace `torch.save` with another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
         """
         # Create a flat dictionary.
         state_dict = {}
......
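With the loader mixins defaulting to `safe_serialization=True`, `save_attn_procs` and `save_lora_weights` now emit `pytorch_lora_weights.safetensors` unless told otherwise. A sketch, assuming LoRA attention processors have already been set on the UNet; the model id and output paths are illustrative:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
# ... set LoRA attention processors on `unet` here ...

unet.save_attn_procs("lora-out")                            # -> pytorch_lora_weights.safetensors
unet.save_attn_procs("lora-out", safe_serialization=False)  # -> pytorch_lora_weights.bin
```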
@@ -272,7 +272,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
         save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
         variant: Optional[str] = None,
         push_to_hub: bool = False,
         **kwargs,
@@ -292,7 +292,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 The function to use to save the state dictionary. Useful during distributed training when you need to
                 replace `torch.save` with another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
-            safe_serialization (`bool`, *optional*, defaults to `False`):
+            safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
......
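The same default now applies to every `ModelMixin.save_pretrained` call, and `from_pretrained` picks up whichever file is present, so round trips keep working in both formats. A sketch with illustrative paths:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

unet.save_pretrained("unet-local")                             # diffusion_pytorch_model.safetensors
unet.save_pretrained("unet-legacy", safe_serialization=False)  # diffusion_pytorch_model.bin

reloaded = UNet2DConditionModel.from_pretrained("unet-local")  # finds the safetensors file
```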
@@ -77,7 +77,7 @@ class MultiControlNetModel(ModelMixin):
         save_directory: Union[str, os.PathLike],
         is_main_process: bool = True,
         save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
         variant: Optional[str] = None,
     ):
         """
@@ -95,7 +95,7 @@ class MultiControlNetModel(ModelMixin):
                 The function to use to save the state dictionary. Useful during distributed training (for example on
                 TPUs) when you need to replace `torch.save` with another method. Can be configured with the
                 environment variable `DIFFUSERS_SAVE_MODE`.
-            safe_serialization (`bool`, *optional*, defaults to `False`):
+            safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
             variant (`str`, *optional*):
                 If specified, weights are saved in the format pytorch_model.<variant>.bin.
......
@@ -556,7 +556,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
         variant: Optional[str] = None,
         push_to_hub: bool = False,
         **kwargs,
@@ -569,7 +569,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         Arguments:
             save_directory (`str` or `os.PathLike`):
                 Directory to save a pipeline to. Will be created if it doesn't exist.
-            safe_serialization (`bool`, *optional*, defaults to `False`):
+            safe_serialization (`bool`, *optional*, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
......
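`DiffusionPipeline.save_pretrained` forwards the flag to each component, so a default save now produces one `.safetensors` file per sub-model (UNet, VAE, text encoder, and so on). A sketch with an illustrative model id and paths:

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

pipe.save_pretrained("sd15-local")                             # *.safetensors per component
pipe.save_pretrained("sd15-legacy", safe_serialization=False)  # legacy *.bin per component
```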
@@ -904,7 +904,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
     ):
         state_dict = {}
......
@@ -1058,7 +1058,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
     ):
         state_dict = {}
......
@@ -1338,7 +1338,7 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
         is_main_process: bool = True,
         weight_name: str = None,
         save_function: Callable = None,
-        safe_serialization: bool = False,
+        safe_serialization: bool = True,
     ):
         state_dict = {}
......
@@ -201,7 +201,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
                 unet_lora_layers=lora_components["unet_lora_layers"],
                 text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
             )
-            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))

     def test_lora_save_load(self):
         pipeline_components, lora_components = self.get_dummy_components()
@@ -220,33 +220,6 @@ class LoraLoaderMixinTests(unittest.TestCase):
                 unet_lora_layers=lora_components["unet_lora_layers"],
                 text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
             )
-            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
-            sd_pipe.load_lora_weights(tmpdirname)
-
-        lora_images = sd_pipe(**pipeline_inputs).images
-        lora_image_slice = lora_images[0, -3:, -3:, -1]
-
-        # Outputs shouldn't match.
-        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
-
-    def test_lora_save_load_safetensors(self):
-        pipeline_components, lora_components = self.get_dummy_components()
-        sd_pipe = StableDiffusionPipeline(**pipeline_components)
-        sd_pipe = sd_pipe.to(torch_device)
-        sd_pipe.set_progress_bar_config(disable=None)
-
-        _, _, pipeline_inputs = self.get_dummy_inputs()
-
-        original_images = sd_pipe(**pipeline_inputs).images
-        orig_image_slice = original_images[0, -3:, -3:, -1]
-
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            LoraLoaderMixin.save_lora_weights(
-                save_directory=tmpdirname,
-                unet_lora_layers=lora_components["unet_lora_layers"],
-                text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
-                safe_serialization=True,
-            )
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
             sd_pipe.load_lora_weights(tmpdirname)
@@ -256,7 +229,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
         # Outputs shouldn't match.
         self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))

-    def test_lora_save_load_legacy(self):
+    def test_lora_save_load_no_safe_serialization(self):
         pipeline_components, lora_components = self.get_dummy_components()
         unet_lora_attn_procs = lora_components["unet_lora_attn_procs"]
         sd_pipe = StableDiffusionPipeline(**pipeline_components)
@@ -271,7 +244,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
         with tempfile.TemporaryDirectory() as tmpdirname:
             unet = sd_pipe.unet
             unet.set_attn_processor(unet_lora_attn_procs)
-            unet.save_attn_procs(tmpdirname)
+            unet.save_attn_procs(tmpdirname, safe_serialization=False)
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))

             sd_pipe.load_lora_weights(tmpdirname)
@@ -368,7 +341,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
                 unet_lora_layers=lora_components["unet_lora_layers"],
                 text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
             )
-            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
             sd_pipe.load_lora_weights(tmpdirname)

         lora_images = sd_pipe(**pipeline_inputs).images
@@ -425,7 +398,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
                 unet_lora_layers=lora_components["unet_lora_layers"],
                 text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
             )
-            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
             sd_pipe.load_lora_weights(tmpdirname)

         lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
@@ -501,7 +474,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
                 unet_lora_layers=lora_components["unet_lora_layers"],
                 text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
             )
-            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
             sd_pipe.load_lora_weights(tmpdirname)

         lora_images = sd_pipe(**pipeline_inputs).images
@@ -629,7 +602,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
                 text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
                 text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
             )
-            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
             sd_pipe.load_lora_weights(tmpdirname)

         lora_images = sd_pipe(**pipeline_inputs).images
@@ -658,7 +631,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
                 text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
                 text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
             )
-            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
             sd_pipe.load_lora_weights(tmpdirname)

         lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
......
@@ -52,7 +52,7 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
         model = torch.compile(model)

         with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_pretrained(tmpdirname)
+            model.save_pretrained(tmpdirname, safe_serialization=False)
             new_model = model_class.from_pretrained(tmpdirname)
             new_model.to(torch_device)
@@ -205,7 +205,7 @@ class ModelTesterMixin:
         model.eval()

         with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_pretrained(tmpdirname)
+            model.save_pretrained(tmpdirname, safe_serialization=False)
             new_model = self.model_class.from_pretrained(tmpdirname)
             if hasattr(new_model, "set_default_attn_processor"):
                 new_model.set_default_attn_processor()
@@ -327,7 +327,7 @@ class ModelTesterMixin:
         model.eval()

         with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_pretrained(tmpdirname, variant="fp16")
+            model.save_pretrained(tmpdirname, variant="fp16", safe_serialization=False)
             new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
             if hasattr(new_model, "set_default_attn_processor"):
                 new_model.set_default_attn_processor()
@@ -372,7 +372,7 @@ class ModelTesterMixin:
                 continue

             with tempfile.TemporaryDirectory() as tmpdirname:
                 model.to(dtype)
-                model.save_pretrained(tmpdirname)
+                model.save_pretrained(tmpdirname, safe_serialization=False)
                 new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
                 assert new_model.dtype == dtype
                 new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
@@ -429,7 +429,7 @@ class ModelTesterMixin:
         # test if the model can be loaded from the config
         # and has all the expected shape
         with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_pretrained(tmpdirname)
+            model.save_pretrained(tmpdirname, safe_serialization=False)
             new_model = self.model_class.from_pretrained(tmpdirname)
             new_model.to(torch_device)
             new_model.eval()
......
@@ -579,7 +579,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

         with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_attn_procs(tmpdirname)
+            model.save_attn_procs(tmpdirname, safe_serialization=False)
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             torch.manual_seed(0)
             new_model = self.model_class(**init_dict)
@@ -643,12 +643,12 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         model.set_attn_processor(lora_attn_procs)

         # Saving as safetensors, properly reloads with the direct filename
         with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_attn_procs(tmpdirname)
-            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            model.save_attn_procs(tmpdirname, safe_serialization=True)
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
             torch.manual_seed(0)
             new_model = self.model_class(**init_dict)
             new_model.to(torch_device)
-            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
+            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors")

     def test_lora_save_torch_force_load_safetensors_error(self):
         # enable deterministic behavior for gradient checkpointing
@@ -664,7 +664,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         model.set_attn_processor(lora_attn_procs)

         # Saving as torch, properly reloads with the direct filename
         with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_attn_procs(tmpdirname)
+            model.save_attn_procs(tmpdirname, safe_serialization=False)
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             torch.manual_seed(0)
             new_model = self.model_class(**init_dict)
@@ -775,7 +775,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         sample = model(**inputs_dict).sample

         with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_attn_procs(tmpdirname)
+            model.save_attn_procs(tmpdirname, safe_serialization=False)
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
             torch.manual_seed(0)
             new_model = self.model_class(**init_dict)
......
@@ -252,7 +252,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

         with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_attn_procs(tmpdirname)
+            model.save_attn_procs(tmpdirname, safe_serialization=False)
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             torch.manual_seed(0)
             new_model = self.model_class(**init_dict)
@@ -316,11 +316,11 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         # Saving as safetensors (now the default), properly reloads with the direct filename
         with tempfile.TemporaryDirectory() as tmpdirname:
             model.save_attn_procs(tmpdirname)
-            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
+            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
             torch.manual_seed(0)
             new_model = self.model_class(**init_dict)
             new_model.to(torch_device)
-            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin")
+            new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors")

     def test_lora_save_torch_force_load_safetensors_error(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -335,7 +335,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
         model.set_attn_processor(lora_attn_procs)

         # Saving as torch, properly reloads with the direct filename
         with tempfile.TemporaryDirectory() as tmpdirname:
-            model.save_attn_procs(tmpdirname)
+            model.save_attn_procs(tmpdirname, safe_serialization=False)
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             torch.manual_seed(0)
             new_model = self.model_class(**init_dict)
......
@@ -884,7 +884,7 @@ class CustomPipelineTests(unittest.TestCase):
         )

         with tempfile.TemporaryDirectory() as tmpdirname:
-            pipe.save_pretrained(tmpdirname)
+            pipe.save_pretrained(tmpdirname, safe_serialization=False)
             pipe_new = CustomPipeline.from_pretrained(tmpdirname)
             pipe_new.save_pretrained(tmpdirname)
......
@@ -309,7 +309,7 @@ class PipelineTesterMixin:
         logger.setLevel(diffusers.logging.INFO)

         with tempfile.TemporaryDirectory() as tmpdir:
-            pipe.save_pretrained(tmpdir)
+            pipe.save_pretrained(tmpdir, safe_serialization=False)

             with CaptureLogger(logger) as cap_logger:
                 pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
@@ -597,7 +597,7 @@ class PipelineTesterMixin:
         output = pipe(**inputs)[0]

         with tempfile.TemporaryDirectory() as tmpdir:
-            pipe.save_pretrained(tmpdir)
+            pipe.save_pretrained(tmpdir, safe_serialization=False)
             pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
             pipe_loaded.to(torch_device)
             pipe_loaded.set_progress_bar_config(disable=None)
......