Unverified Commit 7f161871 authored by Vinh H. Pham, committed by GitHub

Modularize Dreambooth LoRA SDXL inferencing during and after training (#6655)



* modularize log validation

* run make style

* revert import wandb

* fix code quality & import wandb

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent f11b922b
@@ -67,6 +67,9 @@ from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module


+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.27.0.dev0")
@@ -140,6 +143,61 @@ Weights for this model are available in Safetensors format.
 model_card.save(os.path.join(repo_folder, "README.md"))


+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    pipeline_args,
+    epoch,
+    is_final_validation=False,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+
+    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+    scheduler_args = {}
+
+    if "variance_type" in pipeline.scheduler.config:
+        variance_type = pipeline.scheduler.config.variance_type
+
+        if variance_type in ["learned", "learned_range"]:
+            variance_type = "fixed_small"
+
+        scheduler_args["variance_type"] = variance_type
+
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+
+    with torch.cuda.amp.autocast():
+        images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
+
+    for tracker in accelerator.trackers:
+        phase_name = "test" if is_final_validation else "validation"
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    phase_name: [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+    return images
+
+
 def import_model_class_from_model_name_or_path(
     pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
 ):
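Note: the only non-obvious step in the new helper is the `variance_type` guard. `DPMSolverMultistepScheduler` is swapped in for sampling, and a config carrying a learned variance ("learned"/"learned_range") from training would otherwise leak into the new scheduler. A minimal standalone sketch of the same guard; the `DDPMScheduler` here is only a stand-in used to fabricate a config with a learned variance, not part of this commit:

```python
from diffusers import DDPMScheduler, DPMSolverMultistepScheduler

# Stand-in scheduler whose config predicts a learned variance range.
ddpm = DDPMScheduler(variance_type="learned_range")

scheduler_args = {}
if "variance_type" in ddpm.config:
    variance_type = ddpm.config.variance_type
    # A learned variance only makes sense for the scheduler it was trained
    # with; fall back to a fixed variance before reusing the config.
    if variance_type in ["learned", "learned_range"]:
        variance_type = "fixed_small"
    scheduler_args["variance_type"] = variance_type

# Keyword args passed to from_config override the matching config values.
dpm = DPMSolverMultistepScheduler.from_config(ddpm.config, **scheduler_args)
assert dpm.config.variance_type == "fixed_small"
```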
@@ -862,7 +920,6 @@ def main(args):
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb

     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
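The first hunk and this one are two halves of a single change: the `import wandb` that used to live inside `main()` moves to module scope, guarded by `is_wandb_available()`, so the module-level `log_validation` can reference `wandb` as well. The availability check in `main()` stays, so a missing install still fails fast when wandb logging is requested. A sketch of the pattern in isolation; the `main(report_to)` function below is hypothetical, not the script's actual signature:

```python
from diffusers.utils import is_wandb_available

# Import once at module scope, only if installed; module-level code can then
# reference `wandb` freely without forcing the dependency on everyone.
if is_wandb_available():
    import wandb


def main(report_to: str):
    # Fail fast only when wandb logging was actually requested.
    if report_to == "wandb" and not is_wandb_available():
        raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
```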
@@ -1615,10 +1672,6 @@ def main(args):
         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-                logger.info(
-                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                    f" {args.validation_prompt}."
-                )
                 # create pipeline
                 if not args.train_text_encoder:
                     text_encoder_one = text_encoder_cls_one.from_pretrained(
@@ -1644,51 +1697,16 @@ def main(args):
                     torch_dtype=weight_dtype,
                 )

-                # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-                scheduler_args = {}
-
-                if "variance_type" in pipeline.scheduler.config:
-                    variance_type = pipeline.scheduler.config.variance_type
-
-                    if variance_type in ["learned", "learned_range"]:
-                        variance_type = "fixed_small"
-
-                    scheduler_args["variance_type"] = variance_type
-
-                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
-                    pipeline.scheduler.config, **scheduler_args
-                )
-
-                pipeline = pipeline.to(accelerator.device)
-                pipeline.set_progress_bar_config(disable=True)
-
-                # run inference
-                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                 pipeline_args = {"prompt": args.validation_prompt}
-
-                with torch.cuda.amp.autocast():
-                    images = [
-                        pipeline(**pipeline_args, generator=generator).images[0]
-                        for _ in range(args.num_validation_images)
-                    ]
-
-                for tracker in accelerator.trackers:
-                    if tracker.name == "tensorboard":
-                        np_images = np.stack([np.asarray(img) for img in images])
-                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                    if tracker.name == "wandb":
-                        tracker.log(
-                            {
-                                "validation": [
-                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                    for i, image in enumerate(images)
-                                ]
-                            }
-                        )
-
-                del pipeline
-                torch.cuda.empty_cache()
+                images = log_validation(
+                    pipeline,
+                    args,
+                    accelerator,
+                    pipeline_args,
+                    epoch,
+                )

     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
@@ -1733,44 +1751,20 @@ def main(args):
             torch_dtype=weight_dtype,
         )

-        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-        scheduler_args = {}
-
-        if "variance_type" in pipeline.scheduler.config:
-            variance_type = pipeline.scheduler.config.variance_type
-
-            if variance_type in ["learned", "learned_range"]:
-                variance_type = "fixed_small"
-
-            scheduler_args["variance_type"] = variance_type
-
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
-
         # load attention processors
         pipeline.load_lora_weights(args.output_dir)

         # run inference
         images = []
         if args.validation_prompt and args.num_validation_images > 0:
-            pipeline = pipeline.to(accelerator.device)
-            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-            images = [
-                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
-                for _ in range(args.num_validation_images)
-            ]
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "test": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
+            pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+            images = log_validation(
+                pipeline,
+                args,
+                accelerator,
+                pipeline_args,
+                epoch,
+                is_final_validation=True,
+            )

     if args.push_to_hub:
...
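Taken together, the last two hunks reduce both validation paths to the same helper; the only differences are the extra `num_inference_steps` for the final run and the `is_final_validation` flag, which switches the tracker key from "validation" to "test". A condensed view of the two call sites after this commit, written with keyword arguments for readability (the diff passes them positionally); `pipeline`, `args`, `accelerator`, and `epoch` are assumed to be set up as in the training script:

```python
# During training: images are logged under "validation" in TensorBoard/W&B.
images = log_validation(
    pipeline,
    args,
    accelerator,
    pipeline_args={"prompt": args.validation_prompt},
    epoch=epoch,
)

# After training, with the LoRA weights loaded: images are logged under "test".
images = log_validation(
    pipeline,
    args,
    accelerator,
    pipeline_args={"prompt": args.validation_prompt, "num_inference_steps": 25},
    epoch=epoch,
    is_final_validation=True,
)
```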