Unverified Commit f11b922b authored by Vinh H. Pham, committed by GitHub

Modularize Dreambooth LoRA SD inferencing during and after training (#6654)



* modularize log validation

* run make style and refactor wandb support

* remove redundant initialization

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 3dd4168d
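
In short: the inline validation/inference code in main() moves into a new log_validation() helper, which is then called in two places. The condensed call pattern below is taken directly from the diff that follows (the surrounding pipeline and argument setup is unchanged):

    # During training, on the main process every args.validation_epochs epochs:
    images = log_validation(pipeline, args, accelerator, pipeline_args, epoch)

    # After training, once the saved LoRA weights are loaded back into the pipeline:
    pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
    images = log_validation(pipeline, args, accelerator, pipeline_args, epoch, is_final_validation=True)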
@@ -66,6 +66,9 @@ from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module
 
+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.27.0.dev0")
@@ -113,6 +116,71 @@ LoRA for the text encoder was enabled: {train_text_encoder}.
     model_card.save(os.path.join(repo_folder, "README.md"))
 
+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    pipeline_args,
+    epoch,
+    is_final_validation=False,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+    scheduler_args = {}
+
+    if "variance_type" in pipeline.scheduler.config:
+        variance_type = pipeline.scheduler.config.variance_type
+
+        if variance_type in ["learned", "learned_range"]:
+            variance_type = "fixed_small"
+
+        scheduler_args["variance_type"] = variance_type
+
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+
+    if args.validation_images is None:
+        images = []
+        for _ in range(args.num_validation_images):
+            with torch.cuda.amp.autocast():
+                image = pipeline(**pipeline_args, generator=generator).images[0]
+                images.append(image)
+    else:
+        images = []
+        for image in args.validation_images:
+            image = Image.open(image)
+            with torch.cuda.amp.autocast():
+                image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+            images.append(image)
+
+    for tracker in accelerator.trackers:
+        phase_name = "test" if is_final_validation else "validation"
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    phase_name: [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+    return images
+
+
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
@@ -684,7 +752,6 @@ def main(args):
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
 
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
@@ -1265,10 +1332,6 @@ def main(args):
         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-                logger.info(
-                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                    f" {args.validation_prompt}."
-                )
                 # create pipeline
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
@@ -1279,26 +1342,6 @@ def main(args):
                     torch_dtype=weight_dtype,
                 )
 
-                # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-                scheduler_args = {}
-
-                if "variance_type" in pipeline.scheduler.config:
-                    variance_type = pipeline.scheduler.config.variance_type
-
-                    if variance_type in ["learned", "learned_range"]:
-                        variance_type = "fixed_small"
-
-                    scheduler_args["variance_type"] = variance_type
-
-                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
-                    pipeline.scheduler.config, **scheduler_args
-                )
-
-                pipeline = pipeline.to(accelerator.device)
-                pipeline.set_progress_bar_config(disable=True)
-
-                # run inference
-                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                 if args.pre_compute_text_embeddings:
                     pipeline_args = {
                         "prompt_embeds": validation_prompt_encoder_hidden_states,
@@ -1307,36 +1350,13 @@ def main(args):
                 else:
                     pipeline_args = {"prompt": args.validation_prompt}
 
-                if args.validation_images is None:
-                    images = []
-                    for _ in range(args.num_validation_images):
-                        with torch.cuda.amp.autocast():
-                            image = pipeline(**pipeline_args, generator=generator).images[0]
-                            images.append(image)
-                else:
-                    images = []
-                    for image in args.validation_images:
-                        image = Image.open(image)
-                        with torch.cuda.amp.autocast():
-                            image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
-                        images.append(image)
-
-                for tracker in accelerator.trackers:
-                    if tracker.name == "tensorboard":
-                        np_images = np.stack([np.asarray(img) for img in images])
-                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                    if tracker.name == "wandb":
-                        tracker.log(
-                            {
-                                "validation": [
-                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                    for i, image in enumerate(images)
-                                ]
-                            }
-                        )
-
-                del pipeline
-                torch.cuda.empty_cache()
+                images = log_validation(
+                    pipeline,
+                    args,
+                    accelerator,
+                    pipeline_args,
+                    epoch,
+                )
 
     # Save the lora layers
     accelerator.wait_for_everyone()
@@ -1364,46 +1384,21 @@ def main(args):
             args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
         )
 
-        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-        scheduler_args = {}
-
-        if "variance_type" in pipeline.scheduler.config:
-            variance_type = pipeline.scheduler.config.variance_type
-
-            if variance_type in ["learned", "learned_range"]:
-                variance_type = "fixed_small"
-
-            scheduler_args["variance_type"] = variance_type
-
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
-
-        pipeline = pipeline.to(accelerator.device)
-
         # load attention processors
         pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
 
         # run inference
         images = []
         if args.validation_prompt and args.num_validation_images > 0:
-            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-            images = [
-                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
-                for _ in range(args.num_validation_images)
-            ]
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "test": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
+            pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+            images = log_validation(
+                pipeline,
+                args,
+                accelerator,
+                pipeline_args,
+                epoch,
+                is_final_validation=True,
+            )
 
         if args.push_to_hub:
             save_model_card(
...
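
For context, here is a minimal sketch of how the refactored helper could be driven outside the training loop, e.g. to re-run the final "test"-phase validation against an already-saved checkpoint. This is an illustration, not part of the commit: the module name, base model id, prompt, and output directory are placeholders, and a CUDA GPU is assumed.

    # Hypothetical standalone driver for the refactored helper (not part of this commit).
    import argparse

    import torch
    from accelerate import Accelerator
    from diffusers import DiffusionPipeline

    from train_dreambooth_lora import log_validation  # assumed module name of the training script

    # Only the attributes log_validation() reads are filled in; all values are placeholders.
    args = argparse.Namespace(
        output_dir="./dreambooth-lora-out",      # assumed location of pytorch_lora_weights.safetensors
        validation_prompt="a photo of sks dog",  # placeholder prompt
        num_validation_images=4,
        validation_images=None,                  # take the text-to-image branch of the helper
        seed=42,
    )

    accelerator = Accelerator()
    pipeline = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",        # placeholder base model
        torch_dtype=torch.float16,
    )
    pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")

    # Mirrors the final-inference hunk above: is_final_validation=True logs under "test"
    # instead of "validation", and the prompt runs with 25 inference steps.
    pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
    images = log_validation(
        pipeline, args, accelerator, pipeline_args, epoch=0, is_final_validation=True
    )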