Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
946d1cb2
Unverified
Commit
946d1cb2
authored
Jan 25, 2023
by
Suraj Patil
Committed by
GitHub
Jan 25, 2023
Browse files
[dreambooth] check the low-precision guard before preparing model (#2102)
check the dtype before preparing model
parent
09779cbb
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
16 deletions
+17
-16
examples/dreambooth/train_dreambooth.py
examples/dreambooth/train_dreambooth.py
+17
-16
No files found.
examples/dreambooth/train_dreambooth.py
View file @
946d1cb2
...
...
@@ -624,6 +624,23 @@ def main(args):
if
args
.
train_text_encoder
:
text_encoder
.
gradient_checkpointing_enable
()
# Check that all trainable models are in full precision
low_precision_error_string
=
(
"Please make sure to always have all model weights in full float32 precision when starting training - even if"
" doing mixed precision training. copy of the weights should still be float32."
)
if
accelerator
.
unwrap_model
(
unet
).
dtype
!=
torch
.
float32
:
raise
ValueError
(
f
"Unet loaded as datatype
{
accelerator
.
unwrap_model
(
unet
).
dtype
}
.
{
low_precision_error_string
}
"
)
if
args
.
train_text_encoder
and
accelerator
.
unwrap_model
(
text_encoder
).
dtype
!=
torch
.
float32
:
raise
ValueError
(
f
"Text encoder loaded as datatype
{
accelerator
.
unwrap_model
(
text_encoder
).
dtype
}
."
f
"
{
low_precision_error_string
}
"
)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if
args
.
allow_tf32
:
...
...
@@ -717,22 +734,6 @@ def main(args):
if
not
args
.
train_text_encoder
:
text_encoder
.
to
(
accelerator
.
device
,
dtype
=
weight_dtype
)
low_precision_error_string
=
(
"Please make sure to always have all model weights in full float32 precision when starting training - even if"
" doing mixed precision training. copy of the weights should still be float32."
)
if
accelerator
.
unwrap_model
(
unet
).
dtype
!=
torch
.
float32
:
raise
ValueError
(
f
"Unet loaded as datatype
{
accelerator
.
unwrap_model
(
unet
).
dtype
}
.
{
low_precision_error_string
}
"
)
if
args
.
train_text_encoder
and
accelerator
.
unwrap_model
(
text_encoder
).
dtype
!=
torch
.
float32
:
raise
ValueError
(
f
"Text encoder loaded as datatype
{
accelerator
.
unwrap_model
(
text_encoder
).
dtype
}
."
f
"
{
low_precision_error_string
}
"
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch
=
math
.
ceil
(
len
(
train_dataloader
)
/
args
.
gradient_accumulation_steps
)
if
overrode_max_train_steps
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment