"...text-generation-inference.git" did not exist on "d503e8f09d23821239b3845974f6e56f013d3d2c"
Unverified Commit 7081a256 authored by Sayak Paul, committed by GitHub

[Examples] Multiple enhancements to the ControlNet training scripts (#7096)



* log_validation unification for controlnet.

* additional fixes.

* remove print.

* better reuse and loading

* make final inference run conditional.

* Update examples/controlnet/README_sdxl.md
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* resize the control image in the snippet.

---------
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
parent 848f9fe6
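Before the diff, here is a condensed sketch of the unified validation flow these changes introduce: one `log_validation()` serves both the in-training runs and a final run after the checkpoint is saved, toggled by `is_final_validation`. This is an illustrative simplification, not the script itself; the `args` fields, image handling, and tracker logging are pared down relative to the real code below.

```python
# Condensed sketch (not the actual training script) of the unified validation flow.
# `args` stands for the script's parsed CLI options; only the wandb tracker is handled here.
import contextlib
import gc

import torch
import wandb
from PIL import Image
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline


def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
    if not is_final_validation:
        # In-training run: validate the live model managed by Accelerate.
        controlnet = accelerator.unwrap_model(controlnet)
    else:
        # Final run: reload the ControlNet that was just saved to args.output_dir.
        controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)

    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        args.pretrained_model_name_or_path, controlnet=controlnet, torch_dtype=weight_dtype
    ).to(accelerator.device)

    # Autocast only while training; the reloaded checkpoint is already in the target dtype.
    inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")

    images, prompts = [], []
    for prompt, image_path in zip(args.validation_prompt, args.validation_image):
        cond_image = Image.open(image_path).convert("RGB")
        with inference_ctx:
            images.append(pipeline(prompt, cond_image, num_inference_steps=20).images[0])
        prompts.append(prompt)

    # In-training runs log under "validation"; the final run logs under "test".
    tracker_key = "test" if is_final_validation else "validation"
    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            tracker.log({tracker_key: [wandb.Image(img, caption=p) for img, p in zip(images, prompts)]})

    # Free GPU memory before control returns to the caller.
    del pipeline
    gc.collect()
    torch.cuda.empty_cache()
    return images
```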
examples/controlnet/README_sdxl.md

@@ -113,7 +113,7 @@ pipe.enable_xformers_memory_efficient_attention()
 # memory optimization.
 pipe.enable_model_cpu_offload()
 
-control_image = load_image("./conditioning_image_1.png")
+control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
 prompt = "pale golden rod circle with old lace background"
 
 # generate image
@@ -128,4 +128,14 @@ image.save("./output.png")
 
 ### Specifying a better VAE
 
-SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
+SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of an alternative VAE (such as [`madebyollin/sdxl-vae-fp16-fix`](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
+If you're using this VAE during training, you need to ensure you're using it during inference too. You do so by:
+
+```diff
++ vae = AutoencoderKL.from_pretrained(vae_path_or_repo_id, torch_dtype=torch.float16)
+controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
+pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
+    base_model_path, controlnet=controlnet, torch_dtype=torch.float16,
++ vae=vae,
+)
+```
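For context, the snippet below expands the README fragment above into a complete inference example that combines both README changes (the 1024x1024 conditioning image and the alternative VAE). It is a sketch: the base model and ControlNet paths are placeholders, not values from this diff.

```python
# Minimal end-to-end inference sketch; base_model_path and controlnet_path are
# placeholders for your SDXL base checkpoint and the ControlNet you trained.
import torch
from diffusers import AutoencoderKL, ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.utils import load_image

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"  # assumed base checkpoint
controlnet_path = "path/to/trained/controlnet"                # assumed training output_dir
vae_path_or_repo_id = "madebyollin/sdxl-vae-fp16-fix"

vae = AutoencoderKL.from_pretrained(vae_path_or_repo_id, torch_dtype=torch.float16)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, vae=vae, torch_dtype=torch.float16
)
# memory optimization
pipe.enable_model_cpu_offload()

# SDXL is trained at 1024x1024, so the conditioning image is resized to match.
control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
prompt = "pale golden rod circle with old lace background"

# generate image
image = pipe(prompt, image=control_image, num_inference_steps=30).images[0]
image.save("./output.png")
```

The fp16-fix VAE avoids the numerical issues the stock SDXL VAE can hit in float16, which is why the README recommends carrying the same VAE from training over to inference.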
examples/controlnet/train_controlnet.py

@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 
 import argparse
+import contextlib
+import gc
 import logging
 import math
 import os
@@ -74,10 +76,15 @@ def image_grid(imgs, rows, cols):
     return grid
 
 
-def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step):
+def log_validation(
+    vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
+):
     logger.info("Running validation... ")
 
-    controlnet = accelerator.unwrap_model(controlnet)
+    if not is_final_validation:
+        controlnet = accelerator.unwrap_model(controlnet)
+    else:
+        controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
 
     pipeline = StableDiffusionControlNetPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
@@ -118,6 +125,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
     )
 
     image_logs = []
+    inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
 
     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
@@ -125,7 +133,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
         images = []
 
         for _ in range(args.num_validation_images):
-            with torch.autocast("cuda"):
+            with inference_ctx:
                 image = pipeline(
                     validation_prompt, validation_image, num_inference_steps=20, generator=generator
                 ).images[0]
@@ -136,6 +144,7 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
             {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
         )
 
+    tracker_key = "test" if is_final_validation else "validation"
     for tracker in accelerator.trackers:
         if tracker.name == "tensorboard":
             for log in image_logs:
@@ -167,10 +176,14 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler
                     image = wandb.Image(image, caption=validation_prompt)
                     formatted_images.append(image)
 
-            tracker.log({"validation": formatted_images})
+            tracker.log({tracker_key: formatted_images})
         else:
             logger.warn(f"image logging not implemented for {tracker.name}")
 
+    del pipeline
+    gc.collect()
+    torch.cuda.empty_cache()
+
     return image_logs
@@ -197,7 +210,7 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st
 def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
     img_str = ""
     if image_logs is not None:
-        img_str = "You can find some example images below.\n"
+        img_str = "You can find some example images below.\n\n"
         for i, log in enumerate(image_logs):
             images = log["images"]
             validation_prompt = log["validation_prompt"]
@@ -1131,6 +1144,22 @@ def main(args):
         controlnet = unwrap_model(controlnet)
         controlnet.save_pretrained(args.output_dir)
 
+        # Run a final round of validation.
+        image_logs = None
+        if args.validation_prompt is not None:
+            image_logs = log_validation(
+                vae=vae,
+                text_encoder=text_encoder,
+                tokenizer=tokenizer,
+                unet=unet,
+                controlnet=None,
+                args=args,
+                accelerator=accelerator,
+                weight_dtype=weight_dtype,
+                step=global_step,
+                is_final_validation=True,
+            )
+
         if args.push_to_hub:
             save_model_card(
                 repo_id,
examples/controlnet/train_controlnet_sdxl.py

@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 
 import argparse
+import contextlib
 import functools
 import gc
 import logging
@@ -65,20 +66,38 @@ check_min_version("0.27.0.dev0")
 logger = get_logger(__name__)
 
 
-def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step):
+def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step, is_final_validation=False):
     logger.info("Running validation... ")
 
-    controlnet = accelerator.unwrap_model(controlnet)
-
-    pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
-        args.pretrained_model_name_or_path,
-        vae=vae,
-        unet=unet,
-        controlnet=controlnet,
-        revision=args.revision,
-        variant=args.variant,
-        torch_dtype=weight_dtype,
-    )
+    if not is_final_validation:
+        controlnet = accelerator.unwrap_model(controlnet)
+        pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
+            args.pretrained_model_name_or_path,
+            vae=vae,
+            unet=unet,
+            controlnet=controlnet,
+            revision=args.revision,
+            variant=args.variant,
+            torch_dtype=weight_dtype,
+        )
+    else:
+        controlnet = ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
+        if args.pretrained_vae_model_name_or_path is not None:
+            vae = AutoencoderKL.from_pretrained(args.pretrained_vae_model_name_or_path, torch_dtype=weight_dtype)
+        else:
+            vae = AutoencoderKL.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="vae", torch_dtype=weight_dtype
+            )
+        pipeline = StableDiffusionXLControlNetPipeline.from_pretrained(
+            args.pretrained_model_name_or_path,
+            vae=vae,
+            controlnet=controlnet,
+            revision=args.revision,
+            variant=args.variant,
+            torch_dtype=weight_dtype,
+        )
+
     pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config)
     pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)
@@ -106,6 +125,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
     )
 
     image_logs = []
+    inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast("cuda")
 
     for validation_prompt, validation_image in zip(validation_prompts, validation_images):
         validation_image = Image.open(validation_image).convert("RGB")
@@ -114,7 +134,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
         images = []
 
         for _ in range(args.num_validation_images):
-            with torch.autocast("cuda"):
+            with inference_ctx:
                 image = pipeline(
                     prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
                 ).images[0]
@@ -124,6 +144,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
             {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
         )
 
+    tracker_key = "test" if is_final_validation else "validation"
     for tracker in accelerator.trackers:
         if tracker.name == "tensorboard":
             for log in image_logs:
@@ -155,7 +176,7 @@ def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step)
                     image = wandb.Image(image, caption=validation_prompt)
                     formatted_images.append(image)
 
-            tracker.log({"validation": formatted_images})
+            tracker.log({tracker_key: formatted_images})
         else:
             logger.warn(f"image logging not implemented for {tracker.name}")
@@ -189,7 +210,7 @@ def import_model_class_from_model_name_or_path(
 def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None):
     img_str = ""
     if image_logs is not None:
-        img_str = "You can find some example images below.\n"
+        img_str = "You can find some example images below.\n\n"
         for i, log in enumerate(image_logs):
             images = log["images"]
             validation_prompt = log["validation_prompt"]
@@ -1228,7 +1249,13 @@ def main(args):
                     if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                         image_logs = log_validation(
-                            vae, unet, controlnet, args, accelerator, weight_dtype, global_step
+                            vae=vae,
+                            unet=unet,
+                            controlnet=controlnet,
+                            args=args,
+                            accelerator=accelerator,
+                            weight_dtype=weight_dtype,
+                            step=global_step,
                         )
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -1244,6 +1271,21 @@
         controlnet = unwrap_model(controlnet)
         controlnet.save_pretrained(args.output_dir)
 
+        # Run a final round of validation.
+        # Setting `vae`, `unet`, and `controlnet` to None to load automatically from `args.output_dir`.
+        image_logs = None
+        if args.validation_prompt is not None:
+            image_logs = log_validation(
+                vae=None,
+                unet=None,
+                controlnet=None,
+                args=args,
+                accelerator=accelerator,
+                weight_dtype=weight_dtype,
+                step=global_step,
+                is_final_validation=True,
+            )
+
         if args.push_to_hub:
             save_model_card(
                 repo_id,