Unverified commit 7081a256, authored by Sayak Paul, committed by GitHub
Browse files

[Examples] Multiple enhancements to the ControlNet training scripts (#7096)



* log_validation unification for controlnet.

* additional fixes.

* remove print.

* better reuse and loading

* make final inference run conditional.

* Update examples/controlnet/README_sdxl.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* resize the control image in the snippet.

---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 848f9fe6
......@@ -113,7 +113,7 @@ pipe.enable_xformers_memory_efficient_attention()
# memory optimization.
pipe.enable_model_cpu_offload()
control_image = load_image("./conditioning_image_1.png")
control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
prompt = "pale golden rod circle with old lace background"
# generate image
......@@ -128,4 +128,14 @@ image.save("./output.png")
### Specifying a better VAE
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of an alternative VAE (such as [`madebyollin/sdxl-vae-fp16-fix`](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
If you're using this VAE during training, you need to ensure you're using it during inference too. You do so by:
```diff
+ vae = AutoencoderKL.from_pretrained(vae_path_or_repo_id, torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
base_model_path, controlnet=controlnet, torch_dtype=torch.float16,
+ vae=vae,
)
......@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
import argparse
import contextlib
import gc
import logging
import math
import os
......@@ -74,10 +76,15 @@ def image_grid(imgs, rows, cols):
return grid
def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
def log_validation(
vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
):
logger.info("Running validation... ")
controlnet = accelerator.unwrap_model(controlnet)
if not is_final_validation:
controlnet = accelerator.unwrap_model(controlnet)
else:
controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
......@@ -118,6 +125,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
)
image_logs = []
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")
......@@ -125,7 +133,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
images = []
for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
with inference_ctx:
image = pipeline(
validation_prompt, validation_image, num_inference_steps=20, generator=generator
).images[0]
......@@ -136,6 +144,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
)
tracker_key = "test" if is_final_validation else "validation"
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
for log in image_logs:
......@@ -167,10 +176,14 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
tracker.log({"validation": formatted_images})
tracker.log({tracker_key: formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
torch.cuda.empty_cache()
return image_logs
......@@ -197,7 +210,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
img_str = ""
if image_logs is not None:
img_str = "You can find some example images below.\n"
img_str = "You can find some example images below.\n\n"
for i, log in enumerate(image_logs):
images = log["images"]
validation_prompt = log["validation_prompt"]
......@@ -1131,6 +1144,22 @@ def main(args):
controlnet = unwrap_model(controlnet)
controlnet.save_pretrained(args.output_dir)
# Run a final round of validation.
image_logs = None
if args.validation_prompt is not None:
image_logs = log_validation(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
controlnet=None,
args=args,
accelerator=accelerator,
weight_dtype=weight_dtype,
step=global_step,
is_final_validation=True,
)
if args.push_to_hub:
save_model_card(
repo_id,
......
......@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
import argparse
import contextlib
import functools
import gc
import logging
......@@ -65,20 +66,38 @@ check_min_version("0.27.0.dev0")
logger = get_logger(__name__)
def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):
def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
logger.info("Running validation... ")
controlnet = accelerator.unwrap_model(controlnet)
if not is_final_validation:
controlnet = accelerator.unwrap_model(controlnet)
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
unet=unet,
controlnet=controlnet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
else:
controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
if args.pretrained_vae_model_name_or_path is not None:
vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype)
else:
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype
)
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
controlnet=controlnet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
unet=unet,
controlnet=controlnet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
......@@ -106,6 +125,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
)
image_logs = []
inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")
......@@ -114,7 +134,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
images = []
for _ in range(args.num_validation_images):
with torch.autocast("cuda"):
with inference_ctx:
image = pipeline(
prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
).images[0]
......@@ -124,6 +144,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
)
tracker_key = "test" if is_final_validation else "validation"
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
for log in image_logs:
......@@ -155,7 +176,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
tracker.log({"validation": formatted_images})
tracker.log({tracker_key: formatted_images})
else:
logger.warn(f"image logging not implemented for {tracker.name}")
......@@ -189,7 +210,7 @@ def import_model_class_from_model_name_or_path(
def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
img_str = ""
if image_logs is not None:
img_str = "You can find some example images below.\n"
img_str = "You can find some example images below.\n\n"
for i, log in enumerate(image_logs):
images = log["images"]
validation_prompt = log["validation_prompt"]
......@@ -1228,7 +1249,13 @@ def main(args):
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
image_logs = log_validation(
vae, unet, controlnet, args, accelerator, weight_dtype, global_step
vae=vae,
unet=unet,
controlnet=controlnet,
args=args,
accelerator=accelerator,
weight_dtype=weight_dtype,
step=global_step,
)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
......@@ -1244,6 +1271,21 @@ def main(args):
controlnet = unwrap_model(controlnet)
controlnet.save_pretrained(args.output_dir)
# Run a final round of validation.
# Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
image_logs = None
if args.validation_prompt is not None:
image_logs = log_validation(
vae=None,
unet=None,
controlnet=None,
args=args,
accelerator=accelerator,
weight_dtype=weight_dtype,
step=global_step,
is_final_validation=True,
)
if args.push_to_hub:
save_model_card(
repo_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment