Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
faef1c72
Commit
faef1c72
authored
Mar 05, 2024
by
Yoach Lacombe
Browse files
fix sampling + free gpu memory after eval
parent
03611f97
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
4 deletions
+12
-4
run_stable_speech_training.py
run_stable_speech_training.py
+12
-4
No files found.
run_stable_speech_training.py
View file @
faef1c72
...
...
@@ -1202,11 +1202,9 @@ def main():
# Prepare everything with accelerate
model
,
optimizer
,
lr_scheduler
=
accelerator
.
prepare
(
model
,
optimizer
,
lr_scheduler
)
sampler
=
LengthGroupedSampler
(
per_device_train_batch_size
,
lengths
=
vectorized_datasets
[
"train"
][
"target_length"
])
logger
.
info
(
"***** Running training *****"
)
logger
.
info
(
f
" Num examples =
{
total_train_steps
*
train_batch_size
*
gradient_accumulation_steps
}
"
)
logger
.
info
(
" Instantaneous batch size per device ="
f
"
{
training_args
.
per_device_train_batch_size
}
"
)
logger
.
info
(
" Instantaneous batch size per device ="
f
"
{
per_device_train_batch_size
}
"
)
logger
.
info
(
" Gradient accumulation steps ="
f
"
{
gradient_accumulation_steps
}
"
)
logger
.
info
(
f
" Total train batch size (w. parallel & distributed) =
{
train_batch_size
*
gradient_accumulation_steps
}
"
...
...
@@ -1341,6 +1339,8 @@ def main():
for
epoch
in
range
(
epochs_trained
,
num_epochs
):
vectorized_datasets
[
"train"
]
=
vectorized_datasets
[
"train"
].
shuffle
(
training_args
.
seed
)
# TODO: add args
sampler
=
LengthGroupedSampler
(
train_batch_size
,
lengths
=
vectorized_datasets
[
"train"
][
"target_length"
])
train_dataloader
=
DataLoader
(
vectorized_datasets
[
"train"
],
collate_fn
=
data_collator
,
...
...
@@ -1450,7 +1450,7 @@ def main():
# TODO: also add prompt ids
# TODO: better gather
generated_audios
,
input_ids
,
prompts
=
accelerator
.
pad_across_processes
((
generated_audios
,
batch
[
"input_ids"
],
batch
[
"prompt_input_ids"
]),
dim
=
1
,
pad_index
=
0
)
generated_audios
,
input_ids
,
prompts
=
accelerator
.
gather_for_metrics
((
generated_audios
,
input_ids
,
prompts
))
generated_audios
,
input_ids
,
prompts
=
accelerator
.
gather_for_metrics
((
generated_audios
,
input_ids
,
prompts
))
eval_preds
.
extend
(
generated_audios
)
eval_descriptions
.
extend
(
input_ids
)
eval_prompts
.
extend
(
prompts
)
...
...
@@ -1494,6 +1494,14 @@ def main():
epoch
=
epoch
,
prefix
=
"eval"
,
)
# release eval batch and relax metrics
eval_metrics
=
[]
eval_preds
=
[]
eval_descriptions
=
[]
eval_prompts
=
[]
batch
=
release_memory
(
batch
)
# flush the train metrics
train_start
=
time
.
time
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment