Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
diffusers
Commits
3511a962
Unverified
Commit
3511a962
authored
May 30, 2024
by
Genius Patrick
Committed by
GitHub
May 30, 2024
Browse files
fix(training): lr scheduler doesn't work properly in distributed scenarios (#8312)
parent
42cae93b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
7 deletions
+18
-7
examples/text_to_image/train_text_to_image_lora.py
examples/text_to_image/train_text_to_image_lora.py
+18
-7
No files found.
examples/text_to_image/train_text_to_image_lora.py
View file @
3511a962
...
@@ -697,17 +697,22 @@ def main():
...
@@ -697,17 +697,22 @@ def main():
)
)
# Scheduler and math around the number of training steps.
# Scheduler and math around the number of training steps.
overrode_max_train_steps
=
False
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_
update
_steps_
per_epoch
=
math
.
ceil
(
len
(
train_dataloader
)
/
args
.
gradient_accumulation_steps
)
num_
warmup
_steps_
for_scheduler
=
args
.
lr_warmup_steps
*
accelerator
.
num_processes
if
args
.
max_train_steps
is
None
:
if
args
.
max_train_steps
is
None
:
args
.
max_train_steps
=
args
.
num_train_epochs
*
num_update_steps_per_epoch
len_train_dataloader_after_sharding
=
math
.
ceil
(
len
(
train_dataloader
)
/
accelerator
.
num_processes
)
overrode_max_train_steps
=
True
num_update_steps_per_epoch
=
math
.
ceil
(
len_train_dataloader_after_sharding
/
args
.
gradient_accumulation_steps
)
num_training_steps_for_scheduler
=
(
args
.
num_train_epochs
*
num_update_steps_per_epoch
*
accelerator
.
num_processes
)
else
:
num_training_steps_for_scheduler
=
args
.
max_train_steps
*
accelerator
.
num_processes
lr_scheduler
=
get_scheduler
(
lr_scheduler
=
get_scheduler
(
args
.
lr_scheduler
,
args
.
lr_scheduler
,
optimizer
=
optimizer
,
optimizer
=
optimizer
,
num_warmup_steps
=
args
.
lr
_warmup_steps
*
accelerator
.
num_processes
,
num_warmup_steps
=
num
_warmup_steps
_for_scheduler
,
num_training_steps
=
args
.
max
_train_steps
*
accelerator
.
num_processes
,
num_training_steps
=
num
_train
ing
_steps
_for_scheduler
,
)
)
# Prepare everything with our `accelerator`.
# Prepare everything with our `accelerator`.
...
@@ -717,8 +722,14 @@ def main():
...
@@ -717,8 +722,14 @@ def main():
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch
=
math
.
ceil
(
len
(
train_dataloader
)
/
args
.
gradient_accumulation_steps
)
num_update_steps_per_epoch
=
math
.
ceil
(
len
(
train_dataloader
)
/
args
.
gradient_accumulation_steps
)
if
overrode_
max_train_steps
:
if
args
.
max_train_steps
is
None
:
args
.
max_train_steps
=
args
.
num_train_epochs
*
num_update_steps_per_epoch
args
.
max_train_steps
=
args
.
num_train_epochs
*
num_update_steps_per_epoch
if
num_training_steps_for_scheduler
!=
args
.
max_train_steps
*
accelerator
.
num_processes
:
logger
.
warning
(
f
"The length of the 'train_dataloader' after 'accelerator.prepare' (
{
len
(
train_dataloader
)
}
) does not match "
f
"the expected length (
{
len_train_dataloader_after_sharding
}
) when the learning rate scheduler was created. "
f
"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
# Afterwards we recalculate our number of training epochs
args
.
num_train_epochs
=
math
.
ceil
(
args
.
max_train_steps
/
num_update_steps_per_epoch
)
args
.
num_train_epochs
=
math
.
ceil
(
args
.
max_train_steps
/
num_update_steps_per_epoch
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment