Unverified Commit 2f9a70aa authored by Patrick von Platen, committed by GitHub

[LoRA] Make sure validation works in multi GPU setup (#2172)

* [LoRA] Make sure validation works in multi GPU setup

* more fixes

* up
parent e43e206d
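The core of the change: when the script is launched with 🤗 Accelerate on several GPUs, every process would previously execute the validation block, so each rank built its own inference pipeline and logged its own images. The fix gates that block on `accelerator.is_main_process` so only rank 0 runs validation. A minimal sketch of the pattern, assuming an already-prepared `accelerator`, `unet`, and the usual training `args`; the helper name and exact arguments are illustrative, not the script's code:

```python
import torch
from diffusers import DiffusionPipeline


def run_validation(accelerator, unet, args, epoch, weight_dtype):
    """Illustrative helper: run the validation pass on the main process only."""
    # Every rank reaches this point, but only rank 0 should do the work.
    if not accelerator.is_main_process:
        return None
    if args.validation_prompt is None or epoch % args.validation_epochs != 0:
        return None

    # Build an inference pipeline around the trained (still wrapped) UNet.
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        unet=accelerator.unwrap_model(unet),  # strip the distributed wrapper first
        torch_dtype=weight_dtype,
    )
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=True)

    # Seeded generator so images are comparable between validation runs.
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
    images = []
    for _ in range(args.num_validation_images):
        images.append(
            pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
        )
    return images
```

Under `accelerate launch` with multiple processes, the non-main ranks simply skip this function and continue with the next epoch.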
@@ -923,6 +923,7 @@ def main(args):
             if global_step >= args.max_train_steps:
                 break
 
+        if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                 logger.info(
                     f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
@@ -942,8 +943,10 @@ def main(args):
                 # run inference
                 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
-                prompt = args.num_validation_images * [args.validation_prompt]
-                images = pipeline(prompt, num_inference_steps=25, generator=generator).images
+                images = [
+                    pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+                    for _ in range(args.num_validation_images)
+                ]
 
                 for tracker in accelerator.trackers:
                     if tracker.name == "tensorboard":
@@ -982,8 +985,10 @@ def main(args):
     # run inference
     if args.validation_prompt and args.num_validation_images > 0:
         generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-        prompt = args.num_validation_images * [args.validation_prompt]
-        images = pipeline(prompt, num_inference_steps=25, generator=generator).images
+        images = [
+            pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+            for _ in range(args.num_validation_images)
+        ]
 
         for tracker in accelerator.trackers:
             if tracker.name == "tensorboard":
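Besides the process guard, the two hunks above also change how the validation images are produced: the old code made one batched call with `args.num_validation_images` copies of the prompt, while the new code calls the pipeline once per image with a single seeded generator, matching the style of the second script below. A rough sketch of the two variants, assuming a loaded `pipeline`; the helper and its parameter names are illustrative:

```python
import torch


def sample_validation_images(pipeline, prompt, num_images, device, seed=0, steps=25):
    generator = torch.Generator(device=device).manual_seed(seed)

    # Old variant: one batched forward pass over `num_images` copies of the prompt.
    # Peak memory grows with the number of validation images.
    # images = pipeline(num_images * [prompt], num_inference_steps=steps, generator=generator).images

    # New variant: one image per call, reusing the same seeded generator,
    # so memory stays at batch size 1 and the run is still reproducible.
    return [
        pipeline(prompt, num_inference_steps=steps, generator=generator).images[0]
        for _ in range(num_images)
    ]
```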
@@ -749,6 +749,7 @@ def main():
             if global_step >= args.max_train_steps:
                 break
 
+        if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
                 logger.info(
                     f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
@@ -768,7 +769,9 @@ def main():
                 generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
                 images = []
                 for _ in range(args.num_validation_images):
-                    images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
+                    images.append(
+                        pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
+                    )
 
                 if accelerator.is_main_process:
                     for tracker in accelerator.trackers: