Unverified commit 0db19da0, authored by Ben Evans, committed by GitHub

Log Unconditional Image Generation Samples to W&B (#2287)



* Log Unconditional Image Generation Samples to WandB

* Check for wandb installation and keep parity with the onnxruntime script

* Log epoch to wandb

* Check for tensorboard logger early on

* Style fixes

---------
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 62b3c9e0
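
In practice, the pattern these diffs add boils down to: verify that the chosen logger is installed, import `wandb` lazily, and hand a list of `wandb.Image` objects to the `accelerate` tracker. Below is a minimal, self-contained sketch of that pattern, not the training script itself; the project name, the random stand-in images, and the zeroed `epoch`/`global_step` are placeholders for illustration.

```python
# Minimal sketch of the W&B logging pattern added in this commit (illustrative only).
import numpy as np
from accelerate import Accelerator
from diffusers.utils import is_wandb_available

if not is_wandb_available():
    raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb

# Placeholder tracker setup; the real script builds the Accelerator from its CLI arguments.
accelerator = Accelerator(log_with="wandb")
accelerator.init_trackers("unconditional-image-generation-sketch")

# Stand-in for the denormalized pipeline output: a batch of uint8 images in NHWC layout.
images_processed = (np.random.rand(4, 64, 64, 3) * 255).round().astype("uint8")
epoch, global_step = 0, 0

# Same call the diff adds: one wandb.Image per sample, plus the epoch, keyed to the global step.
accelerator.get_tracker("wandb").log(
    {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
    step=global_step,
)
accelerator.end_training()
```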
@@ -21,7 +21,7 @@ import diffusers
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version
+from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -220,6 +220,7 @@ def parse_args():
         help="Whether the model should predict the 'epsilon'/noise error or directly the reconstructed image 'x0'.",
     )
     parser.add_argument("--ddpm_num_steps", type=int, default=1000)
+    parser.add_argument("--ddpm_num_inference_steps", type=int, default=1000)
     parser.add_argument("--ddpm_beta_schedule", type=str, default="linear")
     parser.add_argument(
         "--checkpointing_steps",
@@ -271,6 +272,15 @@ def main(args):
         logging_dir=logging_dir,
     )

+    if args.logger == "tensorboard":
+        if not is_tensorboard_available():
+            raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
+
+    elif args.logger == "wandb":
+        if not is_wandb_available():
+            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+        import wandb
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -552,7 +562,7 @@
                     generator=generator,
                     batch_size=args.eval_batch_size,
                     output_type="numpy",
-                    num_inference_steps=args.ddpm_num_steps,
+                    num_inference_steps=args.ddpm_num_inference_steps,
                 ).images

                 # denormalize the images and save to tensorboard
@@ -562,6 +572,11 @@
                     accelerator.get_tracker("tensorboard").add_images(
                         "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                     )
+                elif args.logger == "wandb":
+                    accelerator.get_tracker("wandb").log(
+                        {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
+                        step=global_step,
+                    )

             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                 # save the model
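With the hunks above applied, launching the script with `--logger wandb` is enough to get sample grids into a W&B run, for example `accelerate launch train_unconditional.py --logger wandb --eval_batch_size 16 --ddpm_num_inference_steps 1000` (the script filename is assumed here; only `--logger`, `--eval_batch_size`, and `--ddpm_num_inference_steps` appear in the diff above). The new `--ddpm_num_inference_steps` flag decouples the number of sampling steps used for these preview images from `--ddpm_num_steps`, which defines the training noise schedule, so previews can be made cheaper without touching training. The second file in the commit, the onnxruntime variant of the script, receives the same availability guard and logging branch: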
@@ -22,7 +22,7 @@ import diffusers
 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
-from diffusers.utils import check_min_version, is_tensorboard_available
+from diffusers.utils import check_min_version, is_tensorboard_available, is_wandb_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -280,6 +280,15 @@ def main(args):
         logging_dir=logging_dir,
     )

+    if args.logger == "tensorboard":
+        if not is_tensorboard_available():
+            raise ImportError("Make sure to install tensorboard if you want to use it for logging during training.")
+
+    elif args.logger == "wandb":
+        if not is_wandb_available():
+            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
+        import wandb
+
     # `accelerate` 0.16.0 will have better support for customized saving
     if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
         # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -604,10 +613,15 @@ def main(args):
                 # denormalize the images and save to tensorboard
                 images_processed = (images * 255).round().astype("uint8")

-                if args.logger == "tensorboard" and is_tensorboard_available():
+                if args.logger == "tensorboard":
                     accelerator.get_tracker("tensorboard").add_images(
                         "test_samples", images_processed.transpose(0, 3, 1, 2), epoch
                     )
+                elif args.logger == "wandb":
+                    accelerator.get_tracker("wandb").log(
+                        {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
+                        step=global_step,
+                    )

             if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1:
                 # save the model