chenpangpang / parler-tts

Commit 6f5a0277
Authored Mar 05, 2024 by Yoach Lacombe

add warnings for broken resume from + fix eval

Parent: 98e1fe31

Showing 1 changed file with 11 additions and 3 deletions

run_stable_speech_training.py (+11 -3)
@@ -1279,10 +1279,18 @@ def main():
     for epoch in range(0, epochs_trained):
         vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
     if training_args.max_steps < 0:
         # we know exactly the number of steps per epoch, so can skip through the required number of batches
-        resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
+        resume_step = (cur_step - epochs_trained * steps_per_epoch)
+        # TODO: currently broken
+        if resume_step == round(len(vectorized_datasets["train"]) / train_batch_size):
+            resume_step = None
+            vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
+            epochs_trained += 1
     else:
         # Currently we don't know how many steps we've taken in the current epoch
         # So we just shuffle the dataset one extra time and start from a fresh epoch
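
The arithmetic on the added side of this hunk can be sketched in isolation. This is only an illustration of the visible lines, not the script's actual code path: the function name `resume_position` and the parameter `num_train_examples` (standing in for `len(vectorized_datasets["train"])`) are hypothetical, and the sketch does not address whatever the `# TODO: currently broken` note refers to.

```python
from typing import Optional, Tuple


def resume_position(
    cur_step: int,
    epochs_trained: int,
    steps_per_epoch: int,
    num_train_examples: int,  # hypothetical stand-in for len(vectorized_datasets["train"])
    train_batch_size: int,
) -> Tuple[Optional[int], int]:
    """Return (resume_step, epochs_trained) as computed on the added side of the hunk."""
    # Steps already taken inside the partially completed epoch.
    resume_step = cur_step - epochs_trained * steps_per_epoch

    # If that count equals the rounded number of batches in the dataset,
    # treat the epoch as finished: no batches to skip, one more epoch counted.
    # (The script also reshuffles the training set in this branch.)
    if resume_step == round(num_train_examples / train_batch_size):
        return None, epochs_trained + 1

    return resume_step, epochs_trained


if __name__ == "__main__":
    # 250 steps done, 2 full epochs of 100 steps each: 50 steps into epoch 3.
    print(resume_position(250, 2, 100, 1600, 16))
    # 300 steps done but only 2 epochs counted: lands on the boundary check.
    print(resume_position(300, 2, 100, 1600, 16))
```

Returning `None` for `resume_step` mirrors the `else` branch of the hunk, where the script gives up on skipping batches and starts from a fresh epoch.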
@@ -1412,7 +1420,7 @@ def main():
         vectorized_datasets["eval"],
         collate_fn=data_collator,
         batch_size=per_device_eval_batch_size,
-        drop_last=False,
+        drop_last=True,
         num_workers=training_args.dataloader_pin_memory,
         pin_memory=training_args.dataloader_pin_memory,
     )
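
To make the effect of the `drop_last` change concrete, here is a small pure-Python sketch of how the batch sizes differ. It assumes the call in this hunk is a `torch.utils.data.DataLoader`, where `drop_last=True` discards a final batch smaller than `batch_size`. The helper `batch_sizes` is hypothetical and does not import PyTorch.

```python
from typing import List


def batch_sizes(num_examples: int, batch_size: int, drop_last: bool) -> List[int]:
    """Sizes of the batches a sequential loader would yield over num_examples items."""
    full, remainder = divmod(num_examples, batch_size)
    sizes = [batch_size] * full
    # With drop_last=False the trailing partial batch is kept;
    # with drop_last=True it is discarded.
    if remainder and not drop_last:
        sizes.append(remainder)
    return sizes


if __name__ == "__main__":
    print(batch_sizes(10, 4, drop_last=False))
    print(batch_sizes(10, 4, drop_last=True))
```

With `drop_last=True`, every evaluation batch has the same size, at the cost of skipping up to `batch_size - 1` examples at the end of the evaluation set.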