"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "e0e9538044906b3ac144e422d620f00e96d7d8b5"
Unverified commit ebd44957, authored by Will Berman, committed by GitHub

image generation main process checks (#2631)

parent e2d9a9be
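All hunks in this commit apply the same pattern: side effects that should happen exactly once per step (checkpoint writing, validation-image generation, tracker logging) are gated behind `accelerator.is_main_process`. Below is a minimal sketch of that pattern, assuming an `accelerate`-based training loop; `run_validation` and the loop body are hypothetical stand-ins for the real training step and `log_validation`, and only the gating structure mirrors the diff.

```python
# Minimal sketch of the main-process gating pattern (hypothetical loop;
# run_validation is a placeholder, not a function from the real scripts).
import os

from accelerate import Accelerator


def run_validation(args, step):
    # Placeholder: the real scripts build an inference pipeline and
    # generate validation images here.
    print(f"[rank 0] validating at step {step} with prompt {args.validation_prompt!r}")


def train(args, accelerator: Accelerator, max_steps=1000):
    global_step = 0
    while global_step < max_steps:
        # ... forward/backward pass and optimizer step would go here ...
        global_step += 1

        # Single outer check: with N distributed workers, everything in
        # this block runs once instead of N times.
        if accelerator.is_main_process:
            if global_step % args.checkpointing_steps == 0:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)

            if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                run_validation(args, global_step)
```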
```diff
@@ -1000,13 +1000,14 @@ def main(args):
                 progress_bar.update(1)
                 global_step += 1
 
-                if global_step % args.checkpointing_steps == 0:
-                    if accelerator.is_main_process:
+                if accelerator.is_main_process:
+                    if global_step % args.checkpointing_steps == 0:
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
 
-                if args.validation_prompt is not None and global_step % args.validation_steps == 0:
-                    log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
+                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+                        log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
```
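Note the inversion in this hunk: previously the step-count check was outermost and only `accelerator.save_state` was restricted to the main process, while `log_validation` ran on every worker. After the change, `accelerator.is_main_process` is the outer guard, so both checkpointing and validation-image generation happen once per interval rather than once per process.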
```diff
@@ -864,7 +864,7 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+        if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
             logger.info(
                 f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                 f" {args.validation_prompt}."
```
```diff
@@ -790,7 +790,7 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
-        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+        if accelerator.is_main_process and args.validation_prompt is not None and epoch % args.validation_epochs == 0:
             logger.info(
                 f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                 f" {args.validation_prompt}."
```
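These two hunks make the same one-line change in two places (this view omits the file names): the epoch-level validation block previously ran on every process, so with N workers the validation images and tracker logs were produced N times. Prepending `accelerator.is_main_process` to the condition restricts the whole block to rank 0.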
```diff
@@ -800,20 +800,19 @@ def main():
                 pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
             )
 
-        if accelerator.is_main_process:
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "validation": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
+        for tracker in accelerator.trackers:
+            if tracker.name == "tensorboard":
+                np_images = np.stack([np.asarray(img) for img in images])
+                tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
+            if tracker.name == "wandb":
+                tracker.log(
+                    {
+                        "validation": [
+                            wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
+                            for i, image in enumerate(images)
+                        ]
+                    }
+                )
 
         del pipeline
         torch.cuda.empty_cache()
```
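With the validation path now entered only on the main process (per the hunks above), the tracker loop's own `if accelerator.is_main_process:` guard became redundant; this hunk removes it and dedents the body, leaving the behavior on rank 0 unchanged.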
```diff
@@ -843,13 +843,14 @@ def main():
                     save_path = os.path.join(args.output_dir, f"learned_embeds-steps-{global_step}.bin")
                     save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
 
-                if global_step % args.checkpointing_steps == 0:
-                    if accelerator.is_main_process:
+                if accelerator.is_main_process:
+                    if global_step % args.checkpointing_steps == 0:
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
 
-                if args.validation_prompt is not None and global_step % args.validation_steps == 0:
-                    log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
+                    if args.validation_prompt is not None and global_step % args.validation_steps == 0:
+                        log_validation(text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch)
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
             progress_bar.set_postfix(**logs)
```