Unverified Commit 37e9d695 authored by satani99's avatar satani99 Committed by GitHub
Browse files

Modularize instruct_pix2pix SD inferencing during and after training in examples (#7603)



* Modularize instruct_pix2pix code

* quality check

* quality check

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent a402431d
......@@ -53,6 +53,9 @@ from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.28.0.dev0")
......@@ -64,6 +67,48 @@ DATASET_NAME_MAPPING = {
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
def log_validation(
pipeline,
args,
accelerator,
generator,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
original_image = download_image(args.val_image_url)
edited_images = []
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
with autocast_ctx:
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
tracker.log({"validation": wandb_table})
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
parser.add_argument(
......@@ -411,11 +456,6 @@ def main():
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
......@@ -517,7 +557,8 @@ def main():
model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if weights:
weights.pop()
def load_model_hook(models, input_dir):
if args.use_ema:
......@@ -923,11 +964,6 @@ def main():
and (args.validation_prompt is not None)
and (epoch % args.validation_epochs == 0)
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# create pipeline
if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
ema_unet.store(unet.parameters())
......@@ -942,38 +978,14 @@ def main():
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
original_image = download_image(args.val_image_url)
edited_images = []
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
with autocast_ctx:
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"validation": wandb_table})
log_validation(
pipeline,
args,
accelerator,
generator,
)
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
......@@ -984,7 +996,6 @@ def main():
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())
......@@ -992,7 +1003,7 @@ def main():
args.pretrained_model_name_or_path,
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
unet=unet,
unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
)
......@@ -1006,31 +1017,13 @@ def main():
ignore_patterns=["step_*", "epoch_*"],
)
if args.validation_prompt is not None:
edited_images = []
pipeline = pipeline.to(accelerator.device)
with torch.autocast(str(accelerator.device).replace(":0", "")):
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"test": wandb_table})
if (args.val_image_url is not None) and (args.validation_prompt is not None):
log_validation(
pipeline,
args,
accelerator,
generator,
)
accelerator.end_training()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment