Unverified Commit 0e98e839 authored by Pedro Cuenca's avatar Pedro Cuenca Committed by GitHub
Browse files

[lora] Log images when using tensorboard (#2078)

* [lora] Log images when using tensorboard.

* Specify image format instead of transposing.

As discussed with @sayakpaul.

* Style
parent f4dddaf5
...@@ -22,6 +22,7 @@ import warnings ...@@ -22,6 +22,7 @@ import warnings
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -930,6 +931,9 @@ def main(args): ...@@ -930,6 +931,9 @@ def main(args):
images = pipeline(prompt, num_inference_steps=25, generator=generator).images images = pipeline(prompt, num_inference_steps=25, generator=generator).images
for tracker in accelerator.trackers: for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb": if tracker.name == "wandb":
tracker.log( tracker.log(
{ {
...@@ -966,6 +970,9 @@ def main(args): ...@@ -966,6 +970,9 @@ def main(args):
images = pipeline(prompt, num_inference_steps=25, generator=generator).images images = pipeline(prompt, num_inference_steps=25, generator=generator).images
for tracker in accelerator.trackers: for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb": if tracker.name == "wandb":
tracker.log( tracker.log(
{ {
......
...@@ -760,6 +760,9 @@ def main(): ...@@ -760,6 +760,9 @@ def main():
if accelerator.is_main_process: if accelerator.is_main_process:
for tracker in accelerator.trackers: for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb": if tracker.name == "wandb":
tracker.log( tracker.log(
{ {
...@@ -807,6 +810,9 @@ def main(): ...@@ -807,6 +810,9 @@ def main():
if accelerator.is_main_process: if accelerator.is_main_process:
for tracker in accelerator.trackers: for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb": if tracker.name == "wandb":
tracker.log( tracker.log(
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment