"...git@developer.sourcefind.cn:OpenDAS/torch-harmonics.git" did not exist on "77a64b2c77b35f7af7923eacded0c94daf2bb319"
Unverified Commit b053053a authored by Charchit Sharma's avatar Charchit Sharma Committed by GitHub
Browse files

Make InstructPix2Pix Training Script torch.compile compatible (#6558)

* added torch.compile for pix2pix

* required changes
parent 08702fc1
...@@ -49,6 +49,7 @@ from diffusers.optimization import get_scheduler ...@@ -49,6 +49,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
# Will error if the minimal version of diffusers is not installed. Remove at your own risks. # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
...@@ -489,6 +490,11 @@ def main(): ...@@ -489,6 +490,11 @@ def main():
else: else:
raise ValueError("xformers is not available. Make sure it is installed correctly") raise ValueError("xformers is not available. Make sure it is installed correctly")
def unwrap_model(model):
    # Peel off accelerate's distributed/DDP wrapper first, then undo any
    # torch.compile wrapping: a compiled module keeps the raw nn.Module
    # on `._orig_mod`, which is what we need for saving/pipeline reuse.
    unwrapped = accelerator.unwrap_model(model)
    if is_compiled_module(unwrapped):
        unwrapped = unwrapped._orig_mod
    return unwrapped
# `accelerate` 0.16.0 will have better support for customized saving # `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"): if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
...@@ -845,7 +851,7 @@ def main(): ...@@ -845,7 +851,7 @@ def main():
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss # Predict the noise residual and compute loss
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training). # Gather the losses across all processes for logging (if we use distributed training).
...@@ -919,9 +925,9 @@ def main(): ...@@ -919,9 +925,9 @@ def main():
# The models need unwrapping because for compatibility in distributed training mode. # The models need unwrapping because for compatibility in distributed training mode.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet), unet=unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder), text_encoder=unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae), vae=unwrap_model(vae),
revision=args.revision, revision=args.revision,
variant=args.variant, variant=args.variant,
torch_dtype=weight_dtype, torch_dtype=weight_dtype,
...@@ -965,14 +971,14 @@ def main(): ...@@ -965,14 +971,14 @@ def main():
# Create the pipeline using the trained modules and save it. # Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if accelerator.is_main_process: if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet) unet = unwrap_model(unet)
if args.use_ema: if args.use_ema:
ema_unet.copy_to(unet.parameters()) ema_unet.copy_to(unet.parameters())
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained( pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path, args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder), text_encoder=unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae), vae=unwrap_model(vae),
unet=unet, unet=unet,
revision=args.revision, revision=args.revision,
variant=args.variant, variant=args.variant,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment