Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
1652c372
Commit
1652c372
authored
Mar 08, 2024
by
Yoach Lacombe
Browse files
fix save/load accelerator state
parent
daca5721
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
10 deletions
+4
-10
run_stable_speech_training.py
run_stable_speech_training.py
+4
-10
No files found.
run_stable_speech_training.py
View file @
1652c372
...
...
@@ -1298,15 +1298,7 @@ def main():
if
training_args
.
max_steps
<
0
:
# we know exactly the number of steps per epoch, so can skip through the required number of batches
resume_step
=
(
cur_step
-
epochs_trained
*
steps_per_epoch
)
# TODO: currently broken
if
resume_step
==
round
(
len
(
vectorized_datasets
[
"train"
])
/
train_batch_size
):
resume_step
=
None
vectorized_datasets
[
"train"
]
=
vectorized_datasets
[
"train"
].
shuffle
(
training_args
.
seed
)
epochs_trained
+=
1
resume_step
=
(
cur_step
-
epochs_trained
*
steps_per_epoch
)
*
gradient_accumulation_steps
else
:
# Currently we don't know how many steps we've taken in the current epoch
# So we just shuffle the dataset one extra time and start from a fresh epoch
...
...
@@ -1409,7 +1401,9 @@ def main():
# save checkpoint and weights after each save_steps and at the end of training
if
(
cur_step
%
training_args
.
save_steps
==
0
)
or
cur_step
==
total_train_steps
:
intermediate_dir
=
os
.
path
.
join
(
training_args
.
output_dir
,
f
"checkpoint-
{
cur_step
}
-epoch-
{
epoch
}
"
)
accelerator
.
save_state
(
output_dir
=
intermediate_dir
)
# safe_serialization=False to avoid shared tensors saving issue (TODO: it's a temporary fix)
# https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
accelerator
.
save_state
(
output_dir
=
intermediate_dir
,
safe_serialization
=
False
)
accelerator
.
wait_for_everyone
()
if
accelerator
.
is_main_process
:
rotate_checkpoints
(
training_args
.
save_total_limit
,
output_dir
=
training_args
.
output_dir
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment