Unverified Commit a9288b49 authored by SangKim's avatar SangKim Committed by GitHub
Browse files

Modularize InstructPix2Pix SDXL inferencing during and after training in examples (#6569)

parent c5441965
...@@ -55,6 +55,9 @@ from diffusers.utils.import_utils import is_xformers_available ...@@ -55,6 +55,9 @@ from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0") check_min_version("0.26.0.dev0")
...@@ -67,6 +70,57 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"] ...@@ -67,6 +70,57 @@ WANDB_TABLE_COL_NAMES = ["file_name", "edited_image", "edit_prompt"]
TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16} TORCH_DTYPE_MAPPING = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
def log_validation(
pipeline,
args,
accelerator,
generator,
global_step,
is_final_validation=False,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
val_save_dir = os.path.join(args.output_dir, "validation_images")
if not os.path.exists(val_save_dir):
os.makedirs(val_save_dir)
original_image = (
lambda image_url_or_path: load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
)(args.val_image_url_or_path)
with torch.autocast(str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"):
edited_images = []
# Run inference
for val_img_idx in range(args.num_validation_images):
a_val_img = pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
edited_images.append(a_val_img)
# Save validation images
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
logger_name = "test" if is_final_validation else "validation"
tracker.log({logger_name: wandb_table})
def import_model_class_from_model_name_or_path( def import_model_class_from_model_name_or_path(
pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
): ):
...@@ -447,11 +501,6 @@ def main(): ...@@ -447,11 +501,6 @@ def main():
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging. # Make one log on every process with the configuration for debugging.
logging.basicConfig( logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
...@@ -1111,11 +1160,6 @@ def main(): ...@@ -1111,11 +1160,6 @@ def main():
### BEGIN: Perform validation every `validation_epochs` steps ### BEGIN: Perform validation every `validation_epochs` steps
if global_step % args.validation_steps == 0: if global_step % args.validation_steps == 0:
if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None): if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# create pipeline # create pipeline
if args.use_ema: if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference. # Store the UNet parameters temporarily and load the EMA parameters to perform inference.
...@@ -1135,44 +1179,16 @@ def main(): ...@@ -1135,44 +1179,16 @@ def main():
variant=args.variant, variant=args.variant,
torch_dtype=weight_dtype, torch_dtype=weight_dtype,
) )
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
# Save validation images
val_save_dir = os.path.join(args.output_dir, "validation_images")
if not os.path.exists(val_save_dir):
os.makedirs(val_save_dir)
original_image = (
lambda image_url_or_path: load_image(image_url_or_path)
if urlparse(image_url_or_path).scheme
else Image.open(image_url_or_path).convert("RGB")
)(args.val_image_url_or_path)
with torch.autocast(
str(accelerator.device).replace(":0", ""), enabled=accelerator.mixed_precision == "fp16"
):
edited_images = []
for val_img_idx in range(args.num_validation_images):
a_val_img = pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
edited_images.append(a_val_img)
a_val_img.save(os.path.join(val_save_dir, f"step_{global_step}_val_img_{val_img_idx}.png"))
for tracker in accelerator.trackers: log_validation(
if tracker.name == "wandb": pipeline,
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES) args,
for edited_image in edited_images: accelerator,
wandb_table.add_data( generator,
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt global_step,
is_final_validation=False,
) )
tracker.log({"validation": wandb_table})
if args.use_ema: if args.use_ema:
# Switch back to the original UNet parameters. # Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters()) ema_unet.restore(unet.parameters())
...@@ -1187,7 +1203,6 @@ def main(): ...@@ -1187,7 +1203,6 @@ def main():
# Create the pipeline using the trained modules and save it. # Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
unet = unwrap_model(unet)
if args.use_ema: if args.use_ema:
ema_unet.copy_to(unet.parameters()) ema_unet.copy_to(unet.parameters())
...@@ -1198,10 +1213,11 @@ def main(): ...@@ -1198,10 +1213,11 @@ def main():
tokenizer=tokenizer_1, tokenizer=tokenizer_1,
tokenizer_2=tokenizer_2, tokenizer_2=tokenizer_2,
vae=vae, vae=vae,
unet=unet, unet=unwrap_model(unet),
revision=args.revision, revision=args.revision,
variant=args.variant, variant=args.variant,
) )
pipeline.save_pretrained(args.output_dir) pipeline.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
...@@ -1212,30 +1228,15 @@ def main(): ...@@ -1212,30 +1228,15 @@ def main():
ignore_patterns=["step_*", "epoch_*"], ignore_patterns=["step_*", "epoch_*"],
) )
if args.validation_prompt is not None: if (args.val_image_url_or_path is not None) and (args.validation_prompt is not None):
edited_images = [] log_validation(
pipeline = pipeline.to(accelerator.device) pipeline,
with torch.autocast(str(accelerator.device).replace(":0", "")): args,
for _ in range(args.num_validation_images): accelerator,
edited_images.append( generator,
pipeline( global_step,
args.validation_prompt, is_final_validation=True,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
) )
tracker.log({"test": wandb_table})
accelerator.end_training() accelerator.end_training()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment