Unverified commit 2f9a70aa, authored by Patrick von Platen, committed by GitHub
Browse files

[LoRA] Make sure validation works in multi GPU setup (#2172)

* [LoRA] Make sure validation works in multi GPU setup

* more fixes

* up
parent e43e206d
...@@ -923,44 +923,47 @@ def main(args): ...@@ -923,44 +923,47 @@ def main(args):
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
break break
if args.validation_prompt is not None and epoch % args.validation_epochs == 0: if accelerator.is_main_process:
logger.info( if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
f"Running validation... \n Generating {args.num_validation_images} images with prompt:" logger.info(
f" {args.validation_prompt}." f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
) f" {args.validation_prompt}."
# create pipeline )
pipeline = DiffusionPipeline.from_pretrained( # create pipeline
args.pretrained_model_name_or_path, pipeline = DiffusionPipeline.from_pretrained(
unet=accelerator.unwrap_model(unet), args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder), unet=accelerator.unwrap_model(unet),
revision=args.revision, text_encoder=accelerator.unwrap_model(text_encoder),
torch_dtype=weight_dtype, revision=args.revision,
) torch_dtype=weight_dtype,
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) )
pipeline = pipeline.to(accelerator.device) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline.set_progress_bar_config(disable=True) pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) # run inference
prompt = args.num_validation_images * [args.validation_prompt] generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
images = pipeline(prompt, num_inference_steps=25, generator=generator).images images = [
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for tracker in accelerator.trackers: for _ in range(args.num_validation_images)
if tracker.name == "tensorboard": ]
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") for tracker in accelerator.trackers:
if tracker.name == "wandb": if tracker.name == "tensorboard":
tracker.log( np_images = np.stack([np.asarray(img) for img in images])
{ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
"validation": [ if tracker.name == "wandb":
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") tracker.log(
for i, image in enumerate(images) {
] "validation": [
} wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
) for i, image in enumerate(images)
]
del pipeline }
torch.cuda.empty_cache() )
del pipeline
torch.cuda.empty_cache()
# Save the lora layers # Save the lora layers
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
...@@ -982,8 +985,10 @@ def main(args): ...@@ -982,8 +985,10 @@ def main(args):
# run inference # run inference
if args.validation_prompt and args.num_validation_images > 0: if args.validation_prompt and args.num_validation_images > 0:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
prompt = args.num_validation_images * [args.validation_prompt] images = [
images = pipeline(prompt, num_inference_steps=25, generator=generator).images pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
for _ in range(args.num_validation_images)
]
for tracker in accelerator.trackers: for tracker in accelerator.trackers:
if tracker.name == "tensorboard": if tracker.name == "tensorboard":
......
...@@ -749,44 +749,47 @@ def main(): ...@@ -749,44 +749,47 @@ def main():
if global_step >= args.max_train_steps: if global_step >= args.max_train_steps:
break break
if args.validation_prompt is not None and epoch % args.validation_epochs == 0: if accelerator.is_main_process:
logger.info( if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
f"Running validation... \n Generating {args.num_validation_images} images with prompt:" logger.info(
f" {args.validation_prompt}." f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
) f" {args.validation_prompt}."
# create pipeline )
pipeline = DiffusionPipeline.from_pretrained( # create pipeline
args.pretrained_model_name_or_path, pipeline = DiffusionPipeline.from_pretrained(
unet=accelerator.unwrap_model(unet), args.pretrained_model_name_or_path,
revision=args.revision, unet=accelerator.unwrap_model(unet),
torch_dtype=weight_dtype, revision=args.revision,
) torch_dtype=weight_dtype,
pipeline = pipeline.to(accelerator.device) )
pipeline.set_progress_bar_config(disable=True) pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) # run inference
images = [] generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
for _ in range(args.num_validation_images): images = []
images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]) for _ in range(args.num_validation_images):
images.append(
if accelerator.is_main_process: pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
for tracker in accelerator.trackers: )
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images]) if accelerator.is_main_process:
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") for tracker in accelerator.trackers:
if tracker.name == "wandb": if tracker.name == "tensorboard":
tracker.log( np_images = np.stack([np.asarray(img) for img in images])
{ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
"validation": [ if tracker.name == "wandb":
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") tracker.log(
for i, image in enumerate(images) {
] "validation": [
} wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
) for i, image in enumerate(images)
]
del pipeline }
torch.cuda.empty_cache() )
del pipeline
torch.cuda.empty_cache()
# Save the lora layers # Save the lora layers
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment