Unverified Commit 7d631825 authored by Vinh H. Pham, committed by GitHub

Make Dreambooth SD Training Script `torch.compile` compatible (#6532)



* support compile

* make style

* move unwrap_model inside function

* change unwrap call

* run make style

* Update examples/dreambooth/train_dreambooth.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Revert "Update examples/dreambooth/train_dreambooth.py"

This reverts commit 70ab09732e7cfec0b19c497f823ddd1c8259dad0.

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
parent 33d2b5b0
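For context on the changes below: `torch.compile` does not optimize a module in place. It returns a `torch._dynamo.eval_frame.OptimizedModule` wrapper that stores the original module in its `_orig_mod` attribute, and `accelerator.unwrap_model` strips Accelerate's distributed wrapper but (at the time of this commit) not the compile wrapper. Type checks such as the `isinstance(model, type(unwrap_model(unet)))` comparison in the saving hook, and `save_pretrained` on the underlying model, therefore need an extra `_orig_mod` unwrap. A minimal standalone sketch, not part of the diff; the `Linear` is a stand-in for the real UNet:

import torch
from diffusers.utils.torch_utils import is_compiled_module

unet = torch.nn.Linear(4, 4)         # stand-in for the real UNet
compiled = torch.compile(unet)

print(type(compiled).__name__)       # OptimizedModule
print(is_compiled_module(compiled))  # True
print(compiled._orig_mod is unet)    # True: the original module is recoverable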
examples/dreambooth/train_dreambooth.py
@@ -55,6 +55,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import compute_snr
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import is_compiled_module


 if is_wandb_available():
@@ -129,15 +130,12 @@ def log_validation(
     if vae is not None:
         pipeline_args["vae"] = vae

-    if text_encoder is not None:
-        text_encoder = accelerator.unwrap_model(text_encoder)
-
     # create pipeline (note: unet and vae are loaded again in float32)
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
-        unet=accelerator.unwrap_model(unet),
+        unet=unet,
         revision=args.revision,
         variant=args.variant,
         torch_dtype=weight_dtype,
@@ -794,6 +792,7 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
     prompt_embeds = text_encoder(
         text_input_ids,
         attention_mask=attention_mask,
+        return_dict=False,
     )
     prompt_embeds = prompt_embeds[0]
...@@ -931,11 +930,16 @@ def main(args): ...@@ -931,11 +930,16 @@ def main(args):
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
) )
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir): def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process: if accelerator.is_main_process:
for model in models: for model in models:
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder" sub_dir = "unet" if isinstance(model, type(unwrap_model(unet))) else "text_encoder"
model.save_pretrained(os.path.join(output_dir, sub_dir)) model.save_pretrained(os.path.join(output_dir, sub_dir))
# make sure to pop weight so that corresponding model is not saved again # make sure to pop weight so that corresponding model is not saved again
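The new `unwrap_model` helper applies the two unwrap steps in order: first the Accelerate wrapper (e.g. DDP), then the compile wrapper. A hedged standalone sketch of its behavior when both `accelerator.prepare` and `torch.compile` are in play, again with a stand-in module on a single process:

import torch
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module

accelerator = Accelerator()
unet = torch.nn.Linear(8, 8)                  # stand-in for the UNet
prepared = accelerator.prepare(torch.compile(unet))

def unwrap_model(model):
    model = accelerator.unwrap_model(model)   # strip the DDP/Accelerate wrapper first
    return model._orig_mod if is_compiled_module(model) else model

# back to the plain module, so save_pretrained and type checks behave as expected
assert unwrap_model(prepared) is unet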
@@ -946,7 +950,7 @@ def main(args):
             # pop models so that they are not loaded again
             model = models.pop()

-            if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+            if isinstance(model, type(unwrap_model(text_encoder))):
                 # load transformers style into model
                 load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
                 model.config = load_model.config
...@@ -991,15 +995,12 @@ def main(args): ...@@ -991,15 +995,12 @@ def main(args):
" doing mixed precision training. copy of the weights should still be float32." " doing mixed precision training. copy of the weights should still be float32."
) )
if accelerator.unwrap_model(unet).dtype != torch.float32: if unwrap_model(unet).dtype != torch.float32:
raise ValueError( raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")
f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
)
if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32: if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32:
raise ValueError( raise ValueError(
f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}"
f" {low_precision_error_string}"
) )
# Enable TF32 for faster training on Ampere GPUs, # Enable TF32 for faster training on Ampere GPUs,
...@@ -1246,7 +1247,7 @@ def main(args): ...@@ -1246,7 +1247,7 @@ def main(args):
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask, text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
) )
if accelerator.unwrap_model(unet).config.in_channels == channels * 2: if unwrap_model(unet).config.in_channels == channels * 2:
noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1) noisy_model_input = torch.cat([noisy_model_input, noisy_model_input], dim=1)
if args.class_labels_conditioning == "timesteps": if args.class_labels_conditioning == "timesteps":
@@ -1256,8 +1257,8 @@ def main(args):
                 # Predict the noise residual
                 model_pred = unet(
-                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
-                ).sample
+                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels, return_dict=False
+                )[0]

                 if model_pred.shape[1] == 6:
                     model_pred, _ = torch.chunk(model_pred, 2, dim=1)
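The switch from `.sample` to `return_dict=False` plus `[0]`, here and in `encode_prompt` above, is the other half of the compatibility work: the dict-like output objects returned by diffusers and transformers models can cause graph breaks under `torch.compile`, while plain tuples trace cleanly. A toy sketch of the pattern, with a hypothetical `ToyUNet` rather than the real model:

import dataclasses
import torch

@dataclasses.dataclass
class ToyOutput:              # stand-in for diffusers' UNet2DConditionOutput
    sample: torch.Tensor

class ToyUNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Conv2d(4, 4, kernel_size=1)

    def forward(self, x, return_dict=True):
        out = self.proj(x)
        # with return_dict=False the compiled region returns a plain tuple,
        # avoiding dataclass construction inside the traced graph
        return ToyOutput(out) if return_dict else (out,)

model = torch.compile(ToyUNet())
pred = model(torch.randn(1, 4, 8, 8), return_dict=False)[0]   # was: (...).sample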
@@ -1350,9 +1351,9 @@ def main(args):
                     if args.validation_prompt is not None and global_step % args.validation_steps == 0:
                         images = log_validation(
-                            text_encoder,
+                            unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
                             tokenizer,
-                            unet,
+                            unwrap_model(unet),
                             vae,
                             args,
                             accelerator,
@@ -1375,14 +1376,14 @@ def main(args):
         pipeline_args = {}

         if text_encoder is not None:
-            pipeline_args["text_encoder"] = accelerator.unwrap_model(text_encoder)
+            pipeline_args["text_encoder"] = unwrap_model(text_encoder)

         if args.skip_save_text_encoder:
             pipeline_args["text_encoder"] = None

         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
-            unet=accelerator.unwrap_model(unet),
+            unet=unwrap_model(unet),
             revision=args.revision,
             variant=args.variant,
             **pipeline_args,