Unverified Commit 029fb416 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Safetensors] Make safetensors the default way of saving weights (#4235)



* make safetensors default

* set default save method as safetensors

* update tests

* update to support saving safetensors

* update test to account for safetensors default

* update example tests to use safetensors

* update example to support safetensors

* update unet tests for safetensors

* fix failing loader tests

* fix qc issues

* fix pipeline tests

* fix example test

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 852dc76d
...@@ -26,6 +26,7 @@ import warnings ...@@ -26,6 +26,7 @@ import warnings
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import safetensors
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -296,14 +297,19 @@ class CustomDiffusionDataset(Dataset): ...@@ -296,14 +297,19 @@ class CustomDiffusionDataset(Dataset):
return example return example
def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir): def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True):
"""Saves the new token embeddings from the text encoder.""" """Saves the new token embeddings from the text encoder."""
logger.info("Saving embeddings") logger.info("Saving embeddings")
learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
for x, y in zip(modifier_token_id, args.modifier_token): for x, y in zip(modifier_token_id, args.modifier_token):
learned_embeds_dict = {} learned_embeds_dict = {}
learned_embeds_dict[y] = learned_embeds[x] learned_embeds_dict[y] = learned_embeds[x]
        # Match the file extension to the serialization format so the
        # load path below (which looks for "{token}.safetensors" when
        # safe serialization is enabled) can find the saved embedding.
        filename = f"{output_dir}/{y}.safetensors" if safe_serialization else f"{output_dir}/{y}.bin"
        if safe_serialization:
            safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
        else:
            torch.save(learned_embeds_dict, filename)
def parse_args(input_args=None): def parse_args(input_args=None):
...@@ -605,6 +611,11 @@ def parse_args(input_args=None): ...@@ -605,6 +611,11 @@ def parse_args(input_args=None):
action="store_true", action="store_true",
help="Don't apply augmentation during data augmentation when this flag is enabled.", help="Don't apply augmentation during data augmentation when this flag is enabled.",
) )
parser.add_argument(
"--no_safe_serialization",
action="store_true",
help="If specified, save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
)
if input_args is not None: if input_args is not None:
args = parser.parse_args(input_args) args = parser.parse_args(input_args)
...@@ -1244,8 +1255,15 @@ def main(args): ...@@ -1244,8 +1255,15 @@ def main(args):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
unet = unet.to(torch.float32) unet = unet.to(torch.float32)
unet.save_attn_procs(args.output_dir) unet.save_attn_procs(args.output_dir, safe_serialization=not args.no_safe_serialization)
save_new_embed(text_encoder, modifier_token_id, accelerator, args, args.output_dir) save_new_embed(
text_encoder,
modifier_token_id,
accelerator,
args,
args.output_dir,
safe_serialization=not args.no_safe_serialization,
)
# Final inference # Final inference
# Load previous pipeline # Load previous pipeline
...@@ -1256,9 +1274,15 @@ def main(args): ...@@ -1256,9 +1274,15 @@ def main(args):
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device)
# load attention processors # load attention processors
pipeline.unet.load_attn_procs(args.output_dir, weight_name="pytorch_custom_diffusion_weights.bin") weight_name = (
"pytorch_custom_diffusion_weights.safetensors"
if not args.no_safe_serialization
else "pytorch_custom_diffusion_weights.bin"
)
pipeline.unet.load_attn_procs(args.output_dir, weight_name=weight_name)
for token in args.modifier_token: for token in args.modifier_token:
pipeline.load_textual_inversion(args.output_dir, weight_name=f"{token}.bin") token_weight_name = f"{token}.safetensors" if not args.no_safe_serialization else f"{token}.bin"
pipeline.load_textual_inversion(args.output_dir, weight_name=token_weight_name)
# run inference # run inference
if args.validation_prompt and args.num_validation_images > 0: if args.validation_prompt and args.num_validation_images > 0:
......
...@@ -1374,7 +1374,7 @@ def main(args): ...@@ -1374,7 +1374,7 @@ def main(args):
pipeline = pipeline.to(accelerator.device) pipeline = pipeline.to(accelerator.device)
# load attention processors # load attention processors
pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.bin") pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
# run inference # run inference
images = [] images = []
......
...@@ -23,7 +23,7 @@ import tempfile ...@@ -23,7 +23,7 @@ import tempfile
import unittest import unittest
from typing import List from typing import List
import torch import safetensors
from accelerate.utils import write_basic_config from accelerate.utils import write_basic_config
from diffusers import DiffusionPipeline, UNet2DConditionModel from diffusers import DiffusionPipeline, UNet2DConditionModel
...@@ -93,7 +93,7 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -93,7 +93,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args, return_stdout=True) run_command(self._launch_args + test_args, return_stdout=True)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_textual_inversion(self): def test_textual_inversion(self):
...@@ -144,7 +144,7 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -144,7 +144,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_dreambooth_if(self): def test_dreambooth_if(self):
...@@ -170,7 +170,7 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -170,7 +170,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_dreambooth_checkpointing(self): def test_dreambooth_checkpointing(self):
...@@ -272,10 +272,10 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -272,10 +272,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters. # make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys()) is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora) self.assertTrue(is_lora)
...@@ -305,10 +305,10 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -305,10 +305,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# check `text_encoder` is present at all. # check `text_encoder` is present at all.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
keys = lora_state_dict.keys() keys = lora_state_dict.keys()
is_text_encoder_present = any(k.startswith("text_encoder") for k in keys) is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
self.assertTrue(is_text_encoder_present) self.assertTrue(is_text_encoder_present)
...@@ -341,10 +341,10 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -341,10 +341,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters. # make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys()) is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora) self.assertTrue(is_lora)
...@@ -373,10 +373,10 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -373,10 +373,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters. # make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys()) is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora) self.assertTrue(is_lora)
...@@ -406,10 +406,10 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -406,10 +406,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters. # make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys()) is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora) self.assertTrue(is_lora)
...@@ -437,6 +437,7 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -437,6 +437,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
--lr_scheduler constant --lr_scheduler constant
--lr_warmup_steps 0 --lr_warmup_steps 0
--modifier_token <new1> --modifier_token <new1>
--no_safe_serialization
--output_dir {tmpdir} --output_dir {tmpdir}
""".split() """.split()
...@@ -466,7 +467,7 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -466,7 +467,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_text_to_image_checkpointing(self): def test_text_to_image_checkpointing(self):
...@@ -778,7 +779,7 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -778,7 +779,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors")))
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))
def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self): def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
...@@ -1373,7 +1374,7 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -1373,7 +1374,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
def test_custom_diffusion_checkpointing_checkpoints_total_limit(self): def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
...@@ -1390,6 +1391,7 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -1390,6 +1391,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
--max_train_steps=6 --max_train_steps=6
--checkpoints_total_limit=2 --checkpoints_total_limit=2
--checkpointing_steps=2 --checkpointing_steps=2
--no_safe_serialization
""".split() """.split()
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
...@@ -1413,6 +1415,7 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -1413,6 +1415,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
--dataloader_num_workers=0 --dataloader_num_workers=0
--max_train_steps=9 --max_train_steps=9
--checkpointing_steps=2 --checkpointing_steps=2
--no_safe_serialization
""".split() """.split()
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
...@@ -1436,6 +1439,7 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -1436,6 +1439,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
--checkpointing_steps=2 --checkpointing_steps=2
--resume_from_checkpoint=checkpoint-8 --resume_from_checkpoint=checkpoint-8
--checkpoints_total_limit=3 --checkpoints_total_limit=3
--no_safe_serialization
""".split() """.split()
run_command(self._launch_args + resume_run_args) run_command(self._launch_args + resume_run_args)
...@@ -1464,10 +1468,10 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -1464,10 +1468,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters. # make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys()) is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora) self.assertTrue(is_lora)
...@@ -1491,10 +1495,10 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -1491,10 +1495,10 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters. # make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin")) lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys()) is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora) self.assertTrue(is_lora)
......
...@@ -24,6 +24,7 @@ from pathlib import Path ...@@ -24,6 +24,7 @@ from pathlib import Path
import numpy as np import numpy as np
import PIL import PIL
import safetensors
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -157,7 +158,7 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight ...@@ -157,7 +158,7 @@ def log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight
return images return images
def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path): def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path, safe_serialization=True):
logger.info("Saving embeddings") logger.info("Saving embeddings")
learned_embeds = ( learned_embeds = (
accelerator.unwrap_model(text_encoder) accelerator.unwrap_model(text_encoder)
...@@ -165,7 +166,11 @@ def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_p ...@@ -165,7 +166,11 @@ def save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_p
.weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] .weight[min(placeholder_token_ids) : max(placeholder_token_ids) + 1]
) )
learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()} learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
torch.save(learned_embeds_dict, save_path)
if safe_serialization:
safetensors.torch.save_file(learned_embeds_dict, save_path, metadata={"format": "pt"})
else:
torch.save(learned_embeds_dict, save_path)
def parse_args(): def parse_args():
...@@ -409,6 +414,11 @@ def parse_args(): ...@@ -409,6 +414,11 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
) )
parser.add_argument(
"--no_safe_serialization",
action="store_true",
help="If specified, save the checkpoint not in `safetensors` format, but in original PyTorch format instead.",
)
args = parser.parse_args() args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
...@@ -878,7 +888,14 @@ def main(): ...@@ -878,7 +888,14 @@ def main():
global_step += 1 global_step += 1
if global_step % args.save_steps == 0: if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin") save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path) save_progress(
text_encoder,
placeholder_token_ids,
accelerator,
args,
save_path,
safe_serialization=not args.no_safe_serialization,
)
if accelerator.is_main_process: if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0: if global_step % args.checkpointing_steps == 0:
...@@ -936,7 +953,14 @@ def main(): ...@@ -936,7 +953,14 @@ def main():
pipeline.save_pretrained(args.output_dir) pipeline.save_pretrained(args.output_dir)
# Save the newly trained embeddings # Save the newly trained embeddings
        # Keep the extension consistent with the serialization format:
        # safetensors content should not be written under a ".bin" name.
        weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors"
        save_path = os.path.join(args.output_dir, weight_name)
save_progress(text_encoder, placeholder_token_ids, accelerator, args, save_path) save_progress(
text_encoder,
placeholder_token_ids,
accelerator,
args,
save_path,
safe_serialization=not args.no_safe_serialization,
)
if args.push_to_hub: if args.push_to_hub:
save_model_card( save_model_card(
......
...@@ -497,7 +497,8 @@ class UNet2DConditionLoadersMixin: ...@@ -497,7 +497,8 @@ class UNet2DConditionLoadersMixin:
is_main_process: bool = True, is_main_process: bool = True,
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = False, safe_serialization: bool = True,
**kwargs,
): ):
r""" r"""
Save an attention processor to a directory so that it can be reloaded using the Save an attention processor to a directory so that it can be reloaded using the
...@@ -514,7 +515,8 @@ class UNet2DConditionLoadersMixin: ...@@ -514,7 +515,8 @@ class UNet2DConditionLoadersMixin:
The function to use to save the state dictionary. Useful during distributed training when you need to The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
""" """
from .models.attention_processor import ( from .models.attention_processor import (
CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor,
...@@ -1414,7 +1416,7 @@ class LoraLoaderMixin: ...@@ -1414,7 +1416,7 @@ class LoraLoaderMixin:
is_main_process: bool = True, is_main_process: bool = True,
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = False, safe_serialization: bool = True,
): ):
r""" r"""
Save the LoRA parameters corresponding to the UNet and text encoder. Save the LoRA parameters corresponding to the UNet and text encoder.
...@@ -1435,6 +1437,8 @@ class LoraLoaderMixin: ...@@ -1435,6 +1437,8 @@ class LoraLoaderMixin:
The function to use to save the state dictionary. Useful during distributed training when you need to The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
""" """
# Create a flat dictionary. # Create a flat dictionary.
state_dict = {} state_dict = {}
......
...@@ -272,7 +272,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -272,7 +272,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
is_main_process: bool = True, is_main_process: bool = True,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = False, safe_serialization: bool = True,
variant: Optional[str] = None, variant: Optional[str] = None,
push_to_hub: bool = False, push_to_hub: bool = False,
**kwargs, **kwargs,
...@@ -292,7 +292,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin): ...@@ -292,7 +292,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
The function to use to save the state dictionary. Useful during distributed training when you need to The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `False`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*): variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`. If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
......
...@@ -77,7 +77,7 @@ class MultiControlNetModel(ModelMixin): ...@@ -77,7 +77,7 @@ class MultiControlNetModel(ModelMixin):
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
is_main_process: bool = True, is_main_process: bool = True,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = False, safe_serialization: bool = True,
variant: Optional[str] = None, variant: Optional[str] = None,
): ):
""" """
...@@ -95,7 +95,7 @@ class MultiControlNetModel(ModelMixin): ...@@ -95,7 +95,7 @@ class MultiControlNetModel(ModelMixin):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method. Can be configured with the environment variable need to replace `torch.save` by another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`. `DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `False`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
variant (`str`, *optional*): variant (`str`, *optional*):
If specified, weights are saved in the format pytorch_model.<variant>.bin. If specified, weights are saved in the format pytorch_model.<variant>.bin.
......
...@@ -556,7 +556,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -556,7 +556,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
def save_pretrained( def save_pretrained(
self, self,
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
safe_serialization: bool = False, safe_serialization: bool = True,
variant: Optional[str] = None, variant: Optional[str] = None,
push_to_hub: bool = False, push_to_hub: bool = False,
**kwargs, **kwargs,
...@@ -569,7 +569,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -569,7 +569,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
Directory to save a pipeline to. Will be created if it doesn't exist. Directory to save a pipeline to. Will be created if it doesn't exist.
safe_serialization (`bool`, *optional*, defaults to `False`): safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*): variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`. If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
......
...@@ -904,7 +904,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad ...@@ -904,7 +904,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
is_main_process: bool = True, is_main_process: bool = True,
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = False, safe_serialization: bool = True,
): ):
state_dict = {} state_dict = {}
......
...@@ -1058,7 +1058,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L ...@@ -1058,7 +1058,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
is_main_process: bool = True, is_main_process: bool = True,
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = False, safe_serialization: bool = True,
): ):
state_dict = {} state_dict = {}
......
...@@ -1338,7 +1338,7 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS ...@@ -1338,7 +1338,7 @@ class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromS
is_main_process: bool = True, is_main_process: bool = True,
weight_name: str = None, weight_name: str = None,
save_function: Callable = None, save_function: Callable = None,
safe_serialization: bool = False, safe_serialization: bool = True,
): ):
state_dict = {} state_dict = {}
......
...@@ -201,7 +201,7 @@ class LoraLoaderMixinTests(unittest.TestCase): ...@@ -201,7 +201,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
unet_lora_layers=lora_components["unet_lora_layers"], unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
) )
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
def test_lora_save_load(self): def test_lora_save_load(self):
pipeline_components, lora_components = self.get_dummy_components() pipeline_components, lora_components = self.get_dummy_components()
...@@ -220,33 +220,6 @@ class LoraLoaderMixinTests(unittest.TestCase): ...@@ -220,33 +220,6 @@ class LoraLoaderMixinTests(unittest.TestCase):
unet_lora_layers=lora_components["unet_lora_layers"], unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
) )
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
sd_pipe.load_lora_weights(tmpdirname)
lora_images = sd_pipe(**pipeline_inputs).images
lora_image_slice = lora_images[0, -3:, -3:, -1]
# Outputs shouldn't match.
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
def test_lora_save_load_safetensors(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
_, _, pipeline_inputs = self.get_dummy_inputs()
original_images = sd_pipe(**pipeline_inputs).images
orig_image_slice = original_images[0, -3:, -3:, -1]
with tempfile.TemporaryDirectory() as tmpdirname:
LoraLoaderMixin.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
safe_serialization=True,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(tmpdirname) sd_pipe.load_lora_weights(tmpdirname)
...@@ -256,7 +229,7 @@ class LoraLoaderMixinTests(unittest.TestCase): ...@@ -256,7 +229,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
# Outputs shouldn't match. # Outputs shouldn't match.
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
def test_lora_save_load_legacy(self): def test_lora_save_load_no_safe_serialization(self):
pipeline_components, lora_components = self.get_dummy_components() pipeline_components, lora_components = self.get_dummy_components()
unet_lora_attn_procs = lora_components["unet_lora_attn_procs"] unet_lora_attn_procs = lora_components["unet_lora_attn_procs"]
sd_pipe = StableDiffusionPipeline(**pipeline_components) sd_pipe = StableDiffusionPipeline(**pipeline_components)
...@@ -271,7 +244,7 @@ class LoraLoaderMixinTests(unittest.TestCase): ...@@ -271,7 +244,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
unet = sd_pipe.unet unet = sd_pipe.unet
unet.set_attn_processor(unet_lora_attn_procs) unet.set_attn_processor(unet_lora_attn_procs)
unet.save_attn_procs(tmpdirname) unet.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
sd_pipe.load_lora_weights(tmpdirname) sd_pipe.load_lora_weights(tmpdirname)
...@@ -368,7 +341,7 @@ class LoraLoaderMixinTests(unittest.TestCase): ...@@ -368,7 +341,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
unet_lora_layers=lora_components["unet_lora_layers"], unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
) )
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(tmpdirname) sd_pipe.load_lora_weights(tmpdirname)
lora_images = sd_pipe(**pipeline_inputs).images lora_images = sd_pipe(**pipeline_inputs).images
...@@ -425,7 +398,7 @@ class LoraLoaderMixinTests(unittest.TestCase): ...@@ -425,7 +398,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
unet_lora_layers=lora_components["unet_lora_layers"], unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
) )
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(tmpdirname) sd_pipe.load_lora_weights(tmpdirname)
lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
...@@ -501,7 +474,7 @@ class LoraLoaderMixinTests(unittest.TestCase): ...@@ -501,7 +474,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
unet_lora_layers=lora_components["unet_lora_layers"], unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_lora_layers"], text_encoder_lora_layers=lora_components["text_encoder_lora_layers"],
) )
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(tmpdirname) sd_pipe.load_lora_weights(tmpdirname)
lora_images = sd_pipe(**pipeline_inputs).images lora_images = sd_pipe(**pipeline_inputs).images
...@@ -629,7 +602,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase): ...@@ -629,7 +602,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
) )
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(tmpdirname) sd_pipe.load_lora_weights(tmpdirname)
lora_images = sd_pipe(**pipeline_inputs).images lora_images = sd_pipe(**pipeline_inputs).images
...@@ -658,7 +631,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase): ...@@ -658,7 +631,7 @@ class SDXLLoraLoaderMixinTests(unittest.TestCase):
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"], text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"], text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
) )
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
sd_pipe.load_lora_weights(tmpdirname) sd_pipe.load_lora_weights(tmpdirname)
lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
......
...@@ -52,7 +52,7 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout): ...@@ -52,7 +52,7 @@ def _test_from_save_pretrained_dynamo(in_queue, out_queue, timeout):
model = torch.compile(model) model = torch.compile(model)
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname) model.save_pretrained(tmpdirname, safe_serialization=False)
new_model = model_class.from_pretrained(tmpdirname) new_model = model_class.from_pretrained(tmpdirname)
new_model.to(torch_device) new_model.to(torch_device)
...@@ -205,7 +205,7 @@ class ModelTesterMixin: ...@@ -205,7 +205,7 @@ class ModelTesterMixin:
model.eval() model.eval()
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname) model.save_pretrained(tmpdirname, safe_serialization=False)
new_model = self.model_class.from_pretrained(tmpdirname) new_model = self.model_class.from_pretrained(tmpdirname)
if hasattr(new_model, "set_default_attn_processor"): if hasattr(new_model, "set_default_attn_processor"):
new_model.set_default_attn_processor() new_model.set_default_attn_processor()
...@@ -327,7 +327,7 @@ class ModelTesterMixin: ...@@ -327,7 +327,7 @@ class ModelTesterMixin:
model.eval() model.eval()
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, variant="fp16") model.save_pretrained(tmpdirname, variant="fp16", safe_serialization=False)
new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16") new_model = self.model_class.from_pretrained(tmpdirname, variant="fp16")
if hasattr(new_model, "set_default_attn_processor"): if hasattr(new_model, "set_default_attn_processor"):
new_model.set_default_attn_processor() new_model.set_default_attn_processor()
...@@ -372,7 +372,7 @@ class ModelTesterMixin: ...@@ -372,7 +372,7 @@ class ModelTesterMixin:
continue continue
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.to(dtype) model.to(dtype)
model.save_pretrained(tmpdirname) model.save_pretrained(tmpdirname, safe_serialization=False)
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype) new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
assert new_model.dtype == dtype assert new_model.dtype == dtype
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype) new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
...@@ -429,7 +429,7 @@ class ModelTesterMixin: ...@@ -429,7 +429,7 @@ class ModelTesterMixin:
# test if the model can be loaded from the config # test if the model can be loaded from the config
# and has all the expected shape # and has all the expected shape
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname) model.save_pretrained(tmpdirname, safe_serialization=False)
new_model = self.model_class.from_pretrained(tmpdirname) new_model = self.model_class.from_pretrained(tmpdirname)
new_model.to(torch_device) new_model.to(torch_device)
new_model.eval() new_model.eval()
......
...@@ -579,7 +579,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -579,7 +579,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname) model.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0) torch.manual_seed(0)
new_model = self.model_class(**init_dict) new_model = self.model_class(**init_dict)
...@@ -643,12 +643,12 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -643,12 +643,12 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
model.set_attn_processor(lora_attn_procs) model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename # Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname) model.save_attn_procs(tmpdirname, safe_serialization=True)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
torch.manual_seed(0) torch.manual_seed(0)
new_model = self.model_class(**init_dict) new_model = self.model_class(**init_dict)
new_model.to(torch_device) new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin") new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors")
def test_lora_save_torch_force_load_safetensors_error(self): def test_lora_save_torch_force_load_safetensors_error(self):
# enable deterministic behavior for gradient checkpointing # enable deterministic behavior for gradient checkpointing
...@@ -664,7 +664,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -664,7 +664,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
model.set_attn_processor(lora_attn_procs) model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename # Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname) model.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0) torch.manual_seed(0)
new_model = self.model_class(**init_dict) new_model = self.model_class(**init_dict)
...@@ -775,7 +775,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -775,7 +775,7 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
sample = model(**inputs_dict).sample sample = model(**inputs_dict).sample
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname) model.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
torch.manual_seed(0) torch.manual_seed(0)
new_model = self.model_class(**init_dict) new_model = self.model_class(**init_dict)
......
...@@ -252,7 +252,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -252,7 +252,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname) model.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0) torch.manual_seed(0)
new_model = self.model_class(**init_dict) new_model = self.model_class(**init_dict)
...@@ -316,11 +316,11 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -316,11 +316,11 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
# Saving as torch, properly reloads with directly filename # Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname) model.save_attn_procs(tmpdirname)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")))
torch.manual_seed(0) torch.manual_seed(0)
new_model = self.model_class(**init_dict) new_model = self.model_class(**init_dict)
new_model.to(torch_device) new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.bin") new_model.load_attn_procs(tmpdirname, weight_name="pytorch_lora_weights.safetensors")
def test_lora_save_torch_force_load_safetensors_error(self): def test_lora_save_torch_force_load_safetensors_error(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
...@@ -335,7 +335,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test ...@@ -335,7 +335,7 @@ class UNet3DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
model.set_attn_processor(lora_attn_procs) model.set_attn_processor(lora_attn_procs)
# Saving as torch, properly reloads with directly filename # Saving as torch, properly reloads with directly filename
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname) model.save_attn_procs(tmpdirname, safe_serialization=False)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
torch.manual_seed(0) torch.manual_seed(0)
new_model = self.model_class(**init_dict) new_model = self.model_class(**init_dict)
......
...@@ -884,7 +884,7 @@ class CustomPipelineTests(unittest.TestCase): ...@@ -884,7 +884,7 @@ class CustomPipelineTests(unittest.TestCase):
) )
with tempfile.TemporaryDirectory() as tmpdirname: with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname) pipe.save_pretrained(tmpdirname, safe_serialization=False)
pipe_new = CustomPipeline.from_pretrained(tmpdirname) pipe_new = CustomPipeline.from_pretrained(tmpdirname)
pipe_new.save_pretrained(tmpdirname) pipe_new.save_pretrained(tmpdirname)
......
...@@ -309,7 +309,7 @@ class PipelineTesterMixin: ...@@ -309,7 +309,7 @@ class PipelineTesterMixin:
logger.setLevel(diffusers.logging.INFO) logger.setLevel(diffusers.logging.INFO)
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir) pipe.save_pretrained(tmpdir, safe_serialization=False)
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
...@@ -597,7 +597,7 @@ class PipelineTesterMixin: ...@@ -597,7 +597,7 @@ class PipelineTesterMixin:
output = pipe(**inputs)[0] output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir) pipe.save_pretrained(tmpdir, safe_serialization=False)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
pipe_loaded.to(torch_device) pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None) pipe_loaded.set_progress_bar_config(disable=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment