Unverified Commit e44b205e authored by Charchit Sharma's avatar Charchit Sharma Committed by GitHub
Browse files

Make ControlNet SDXL Training Script torch.compile compatible (#6526)

* make torch.compile compatible

* fix quality
parent 60cb4432
...@@ -52,6 +52,7 @@ from diffusers import ( ...@@ -52,6 +52,7 @@ from diffusers import (
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available(): if is_wandb_available():
...@@ -847,6 +848,11 @@ def main(args): ...@@ -847,6 +848,11 @@ def main(args):
logger.info("Initializing controlnet weights from unet") logger.info("Initializing controlnet weights from unet")
controlnet = ControlNetModel.from_unet(unet) controlnet = ControlNetModel.from_unet(unet)
def unwrap_model(model):
    # Strip accelerate's distributed/AMP wrappers first, then peel off the
    # torch.compile wrapper: a compiled module keeps the original module
    # under `_orig_mod`.
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        unwrapped = unwrapped._orig_mod
    return unwrapped
# `accelerate` 0.16.0 will have better support for customized saving # `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
...@@ -908,9 +914,9 @@ def main(args): ...@@ -908,9 +914,9 @@ def main(args):
" doing mixed precision training, copy of the weights should still be float32." " doing mixed precision training, copy of the weights should still be float32."
) )
if accelerator.unwrap_model(controlnet).dtype != torch.float32: if unwrap_model(controlnet).dtype != torch.float32:
raise ValueError( raise ValueError(
f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}"
) )
# Enable TF32 for faster training on Ampere GPUs, # Enable TF32 for faster training on Ampere GPUs,
...@@ -1158,7 +1164,8 @@ def main(args): ...@@ -1158,7 +1164,8 @@ def main(args):
sample.to(dtype=weight_dtype) for sample in down_block_res_samples sample.to(dtype=weight_dtype) for sample in down_block_res_samples
], ],
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
).sample return_dict=False,
)[0]
# Get the target for loss depending on the prediction type # Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon": if noise_scheduler.config.prediction_type == "epsilon":
...@@ -1223,7 +1230,7 @@ def main(args): ...@@ -1223,7 +1230,7 @@ def main(args):
# Create the pipeline using the trained modules and save it. # Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
controlnet = accelerator.unwrap_model(controlnet) controlnet = unwrap_model(controlnet)
controlnet.save_pretrained(args.output_dir) controlnet.save_pretrained(args.output_dir)
if args.push_to_hub: if args.push_to_hub:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment