Unverified Commit a4b2c2f1 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

train_unconditional save restore unet parameters (#2706)

parent 77e0ea80
......@@ -625,8 +625,11 @@ def main(args):
if accelerator.is_main_process:
if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
unet = accelerator.unwrap_model(model)
if args.use_ema:
ema_model.store(unet.parameters())
ema_model.copy_to(unet.parameters())
pipeline = DDPMPipeline(
unet=unet,
scheduler=noise_scheduler,
......@@ -641,6 +644,9 @@ def main(args):
output_type="numpy",
).images
if args.use_ema:
ema_model.restore(unet.parameters())
# denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8")
......@@ -659,7 +665,22 @@ def main(args):
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model
unet = accelerator.unwrap_model(model)
if args.use_ema:
ema_model.store(unet.parameters())
ema_model.copy_to(unet.parameters())
pipeline = DDPMPipeline(
unet=unet,
scheduler=noise_scheduler,
)
pipeline.save_pretrained(args.output_dir)
if args.use_ema:
ema_model.restore(unet.parameters())
if args.push_to_hub:
repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment