".github/vscode:/vscode.git/clone" did not exist on "af336d66944689e1eacd20a2a040e5fc56c31045"
Unverified Commit 2e69cf16 authored by Vladislav Artemyev's avatar Vladislav Artemyev Committed by GitHub
Browse files

Log global_step instead of epoch to tensorboard (#4493)


Co-authored-by: default avatarmrlzla <noname@noname.com>
parent 9c29bc2d
......@@ -107,7 +107,16 @@ DreamBooth for the text encoder was enabled: {train_text_encoder}.
def log_validation(
text_encoder, tokenizer, unet, vae, args, accelerator, weight_dtype, epoch, prompt_embeds, negative_prompt_embeds
text_encoder,
tokenizer,
unet,
vae,
args,
accelerator,
weight_dtype,
global_step,
prompt_embeds,
negative_prompt_embeds,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
......@@ -173,7 +182,7 @@ def log_validation(
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
......@@ -1308,7 +1317,7 @@ def main(args):
args,
accelerator,
weight_dtype,
epoch,
global_step,
validation_prompt_encoder_hidden_states,
validation_prompt_negative_prompt_embeds,
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment