Unverified Commit 2f9a70aa authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[LoRA] Make sure validation works in multi GPU setup (#2172)

* [LoRA] Make sure validation works in multi GPU setup

* more fixes

* up
parent e43e206d
......@@ -923,6 +923,7 @@ def main(args):
if global_step >= args.max_train_steps:
break
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
......@@ -942,8 +943,10 @@ def main(args):
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images)
]
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
......@@ -982,8 +985,10 @@ def main(args):
# run inference
if args.validation_prompt and args.num_validation_images > 0:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images)
]
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
......
......@@ -749,6 +749,7 @@ def main():
if global_step >= args.max_train_steps:
break
if accelerator.is_main_process:
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
......@@ -768,7 +769,9 @@ def main():
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = []
for _ in range(args.num_validation_images):
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
images.append(
pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
)
if accelerator.is_main_process:
for tracker in accelerator.trackers:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment