Unverified Commit fb98acf0 authored by Oren WANG's avatar Oren WANG Committed by GitHub
Browse files

[lora] Fix bug with training without validation (#2106)

parent 180841bb
...@@ -984,19 +984,19 @@ def main(args): ...@@ -984,19 +984,19 @@ def main(args):
prompt = args.num_validation_images * [args.validation_prompt] prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images images = pipeline(prompt, num_inference_steps=25, generator=generator).images
for tracker in accelerator.trackers: for tracker in accelerator.trackers:
if tracker.name == "tensorboard": if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images]) np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
if tracker.name == "wandb": if tracker.name == "wandb":
tracker.log( tracker.log(
{ {
"test": [ "test": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images) for i, image in enumerate(images)
] ]
} }
) )
if args.push_to_hub: if args.push_to_hub:
save_model_card( save_model_card(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment