Unverified Commit 0e98e839 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

[lora] Log images when using tensorboard (#2078)

* [lora] Log images when using tensorboard.

* Specify image format instead of transposing.

As discussed with @sayakpaul.

* Style
parent f4dddaf5
......@@ -22,6 +22,7 @@ import warnings
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
......@@ -930,6 +931,9 @@ def main(args):
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
......@@ -966,6 +970,9 @@ def main(args):
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
......
......@@ -760,6 +760,9 @@ def main():
if accelerator.is_main_process:
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
......@@ -807,6 +810,9 @@ def main():
if accelerator.is_main_process:
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb":
tracker.log(
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment