Commit 06505ba4 authored by anton-l's avatar anton-l
Browse files

Less eval steps during training

parent 13457002
...@@ -147,9 +147,9 @@ def main(args): ...@@ -147,9 +147,9 @@ def main(args):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
# Generate a sample image for visual inspection # Generate sample images for visual inspection
if accelerator.is_main_process: if accelerator.is_main_process:
with torch.no_grad(): if epoch % args.save_images_epochs == 0 or epoch == args.num_epochs - 1:
pipeline = DDPMPipeline( pipeline = DDPMPipeline(
unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model), unet=accelerator.unwrap_model(ema_model.averaged_model if args.use_ema else model),
scheduler=noise_scheduler, scheduler=noise_scheduler,
...@@ -159,9 +159,11 @@ def main(args): ...@@ -159,9 +159,11 @@ def main(args):
# run pipeline in inference (sample random noise and denoise) # run pipeline in inference (sample random noise and denoise)
images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"] images = pipeline(generator=generator, batch_size=args.eval_batch_size, output_type="numpy")["sample"]
# denormalize the images and save to tensorboard # denormalize the images and save to tensorboard
images_processed = (images * 255).round().astype("uint8") images_processed = (images * 255).round().astype("uint8")
accelerator.trackers[0].writer.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch) accelerator.trackers[0].writer.add_images(
"test_samples", images_processed.transpose(0, 3, 1, 2), epoch
)
if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
# save the model # save the model
...@@ -184,7 +186,8 @@ if __name__ == "__main__": ...@@ -184,7 +186,8 @@ if __name__ == "__main__":
parser.add_argument("--train_batch_size", type=int, default=16) parser.add_argument("--train_batch_size", type=int, default=16)
parser.add_argument("--eval_batch_size", type=int, default=16) parser.add_argument("--eval_batch_size", type=int, default=16)
parser.add_argument("--num_epochs", type=int, default=100) parser.add_argument("--num_epochs", type=int, default=100)
parser.add_argument("--save_model_epochs", type=int, default=5) parser.add_argument("--save_images_epochs", type=int, default=10)
parser.add_argument("--save_model_epochs", type=int, default=10)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--learning_rate", type=float, default=1e-4) parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--lr_scheduler", type=str, default="cosine") parser.add_argument("--lr_scheduler", type=str, default="cosine")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment