"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9141c1f9d5d8c40ce73c7d72a2fd5aedab8c031d"
Unverified Commit 946d1cb2 authored by Suraj Patil, committed by GitHub

[dreambooth] check the low-precision guard before preparing model (#2102)

check the dtype before preparing model
parent 09779cbb
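
The point of the move is ordering: the dtype guard now runs before accelerator.prepare() can wrap or cast the models, so a user who accidentally loads fp16/bf16 weights fails fast with a clear error instead of partway into training. Below is a minimal, self-contained sketch of that ordering, not the verbatim training script: a tiny nn.Linear stands in for the real UNet, and the dtype is read from the parameters rather than the diffusers ModelMixin.dtype property the script uses.

# Sketch only (assumed stand-in names, not the training script): validate the
# dtype on the raw model *before* accelerator.prepare() may wrap it
# (e.g. in DistributedDataParallel).
import torch
from accelerate import Accelerator

accelerator = Accelerator()  # mixed_precision is picked up from the accelerate config

unet = torch.nn.Linear(4, 4)  # stand-in for the real UNet2DConditionModel
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)

# Guard first: master weights must be float32 even for mixed-precision runs.
if next(accelerator.unwrap_model(unet).parameters()).dtype != torch.float32:
    raise ValueError("Unet must be loaded in full float32 precision.")

# Only then hand the model to accelerate, which wraps it for the chosen setup.
unet, optimizer = accelerator.prepare(unet, optimizer)
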
@@ -624,6 +624,23 @@ def main(args):
         if args.train_text_encoder:
             text_encoder.gradient_checkpointing_enable()
 
+    # Check that all trainable models are in full precision
+    low_precision_error_string = (
+        "Please make sure to always have all model weights in full float32 precision when starting training - even if"
+        " doing mixed precision training. copy of the weights should still be float32."
+    )
+
+    if accelerator.unwrap_model(unet).dtype != torch.float32:
+        raise ValueError(
+            f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
+        )
+
+    if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
+        raise ValueError(
+            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
+            f" {low_precision_error_string}"
+        )
+
     # Enable TF32 for faster training on Ampere GPUs,
     # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
     if args.allow_tf32:
@@ -717,22 +734,6 @@ def main(args):
     if not args.train_text_encoder:
         text_encoder.to(accelerator.device, dtype=weight_dtype)
 
-    low_precision_error_string = (
-        "Please make sure to always have all model weights in full float32 precision when starting training - even if"
-        " doing mixed precision training. copy of the weights should still be float32."
-    )
-
-    if accelerator.unwrap_model(unet).dtype != torch.float32:
-        raise ValueError(
-            f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
-        )
-
-    if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
-        raise ValueError(
-            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
-            f" {low_precision_error_string}"
-        )
-
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
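
To see what the moved guard actually catches, here is a hedged, torch-only sketch (check_full_precision is an illustrative helper, not part of the script) that raises the same kind of ValueError when weights are accidentally half precision:

import torch
from torch import nn

def check_full_precision(model: nn.Module, name: str) -> None:
    # Mirrors the moved guard: trainable weights must start out as float32.
    dtype = next(model.parameters()).dtype
    if dtype != torch.float32:
        raise ValueError(f"{name} loaded as datatype {dtype}. Expected torch.float32.")

unet = nn.Linear(4, 4)              # stand-in for the UNet
check_full_precision(unet, "Unet")  # passes: parameters default to float32

unet.half()                         # simulate accidentally loading fp16 weights
try:
    check_full_precision(unet, "Unet")
except ValueError as err:
    print(err)  # -> Unet loaded as datatype torch.float16. Expected torch.float32.
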