Unverified Commit 03135898 authored by Yead's avatar Yead Committed by GitHub
Browse files

Fix save_path bug in textual inversion training script (#4710)



* Update textual_inversion.py

fixed safe_path bug in textual inversion training

* Update test_examples.py

update test_textual_inversion for updating saved file's name

* Update textual_inversion.py

fixed some formatting issues

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent fd35689f
...@@ -122,7 +122,7 @@ class ExamplesTestsAccelerate(unittest.TestCase): ...@@ -122,7 +122,7 @@ class ExamplesTestsAccelerate(unittest.TestCase):
run_command(self._launch_args + test_args) run_command(self._launch_args + test_args)
# save_pretrained smoke test # save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.bin"))) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "learned_embeds.safetensors")))
def test_dreambooth(self): def test_dreambooth(self):
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
......
...@@ -887,7 +887,12 @@ def main(): ...@@ -887,7 +887,12 @@ def main():
progress_bar.update(1) progress_bar.update(1)
global_step += 1 global_step += 1
if global_step % args.save_steps == 0: if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin") weight_name = (
f"learned_embeds-steps-{global_step}.bin"
if args.no_safe_serialization
else f"learned_embeds-steps-{global_step}.safetensors"
)
save_path = os.path.join(args.output_dir, weight_name)
save_progress( save_progress(
text_encoder, text_encoder,
placeholder_token_ids, placeholder_token_ids,
...@@ -952,7 +957,8 @@ def main(): ...@@ -952,7 +957,8 @@ def main():
) )
pipeline.save_pretrained(args.output_dir) pipeline.save_pretrained(args.output_dir)
# Save the newly trained embeddings # Save the newly trained embeddings
save_path = os.path.join(args.output_dir, "learned_embeds.bin") weight_name = "learned_embeds.bin" if args.no_safe_serialization else "learned_embeds.safetensors"
save_path = os.path.join(args.output_dir, weight_name)
save_progress( save_progress(
text_encoder, text_encoder,
placeholder_token_ids, placeholder_token_ids,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment