renzhc/diffusers_dcu, commit 3511a962 (unverified), authored May 30, 2024 by Genius Patrick, committed by GitHub
fix(training): lr scheduler doesn't work properly in distributed scenarios (#8312)
parent 42cae93b

Showing 1 changed file with 18 additions and 7 deletions.

examples/text_to_image/train_text_to_image_lora.py  +18 -7
...
@@ -697,17 +697,22 @@ def main():
     )
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )
 
     # Prepare everything with our `accelerator`.
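A minimal sketch, not part of the commit, of the step arithmetic introduced in the hunk above. It assumes the usual Accelerate behavior that the prepared LR scheduler is advanced num_processes times per optimizer update (see the PR linked in the diff), which is why both the warmup and total step counts handed to get_scheduler are scaled by accelerator.num_processes. The dataset_size and per_device_batch_size names are illustrative, not from the script.

import math

def scheduler_step_counts(dataset_size, per_device_batch_size, num_processes,
                          gradient_accumulation_steps, num_train_epochs, lr_warmup_steps):
    # Length of the dataloader before accelerator.prepare() (full dataset, per-device batch size).
    len_train_dataloader = math.ceil(dataset_size / per_device_batch_size)
    # Expected length after prepare() shards the batches across processes.
    len_after_sharding = math.ceil(len_train_dataloader / num_processes)
    # Optimizer updates per epoch on each process.
    updates_per_epoch = math.ceil(len_after_sharding / gradient_accumulation_steps)
    # Both counts are scaled by num_processes, matching the diff above.
    num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes
    num_training_steps_for_scheduler = num_train_epochs * updates_per_epoch * num_processes
    return num_warmup_steps_for_scheduler, num_training_steps_for_scheduler

# Example: 1000 samples, batch size 4 per GPU, 2 GPUs, accumulation 2, 10 epochs, 500 warmup steps:
# 250 batches -> 125 per process -> 63 updates/epoch -> (500 * 2, 10 * 63 * 2) == (1000, 1260).
print(scheduler_step_counts(1000, 4, 2, 2, 10, 500))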
...
@@ -717,8 +722,14 @@ def main():
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
...
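A hedged illustration, with made-up numbers, of the consistency check added in the second hunk: if the dataloader length that accelerator.prepare() actually produces differs from the length estimated when the scheduler was created, the scheduler's total step count no longer equals max_train_steps * num_processes, and the script now warns instead of silently mis-scheduling the learning rate.

import math

# Illustrative values only; in the script these come from args and the prepared dataloader.
num_processes = 2
gradient_accumulation_steps = 2
num_train_epochs = 10

# Estimate used when the scheduler was built (before accelerator.prepare()).
len_train_dataloader_after_sharding = 125
num_training_steps_for_scheduler = (
    num_train_epochs
    * math.ceil(len_train_dataloader_after_sharding / gradient_accumulation_steps)
    * num_processes
)  # 10 * 63 * 2 = 1260

# Suppose the prepared dataloader turns out one batch shorter than estimated.
len_train_dataloader = 124
num_update_steps_per_epoch = math.ceil(len_train_dataloader / gradient_accumulation_steps)  # 62
max_train_steps = num_train_epochs * num_update_steps_per_epoch  # 620

if num_training_steps_for_scheduler != max_train_steps * num_processes:  # 1260 != 1240
    print(
        f"Dataloader length after prepare ({len_train_dataloader}) does not match the expected "
        f"length ({len_train_dataloader_after_sharding}); the LR schedule will end early or late."
    )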