Commit d3379f0b authored by Maruyama_Aya's avatar Maruyama_Aya
Browse files

fixed model saving bugs

parent b29e1f07
......@@ -667,9 +667,9 @@ def main(args):
if global_step % args.save_steps == 0:
torch.cuda.synchronize()
if local_rank == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
if local_rank == 0:
if not os.path.exists(os.path.join(save_path, "config.json")):
shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
......
......@@ -693,9 +693,9 @@ def main(args):
if global_step % args.save_steps == 0:
torch.cuda.synchronize()
if local_rank == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
booster.save_model(unet, os.path.join(save_path, "diffusion_pytorch_model.bin"))
if local_rank == 0:
if not os.path.exists(os.path.join(save_path, "config.json")):
shutil.copy(os.path.join(args.pretrained_model_name_or_path, "unet/config.json"), save_path)
logger.info(f"Saving model checkpoint to {save_path}", ranks=[0])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment