Unverified Commit ebd44957 authored by Will Berman's avatar Will Berman Committed by GitHub
Browse files

image generation main process checks (#2631)

parent e2d9a9be
......@@ -1000,11 +1000,12 @@ def main(args):
progress_bar.update(1)
global_step += 1
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
......
......@@ -864,7 +864,7 @@ def main():
if global_step >= args.max_train_steps:
break
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
......
......@@ -790,7 +790,7 @@ def main():
if global_step >= args.max_train_steps:
break
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
......
......@@ -800,7 +800,6 @@ def main():
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
)
if accelerator.is_main_process:
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
......
......@@ -843,11 +843,12 @@ def main():
save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment