chenpangpang / parler-tts · Commits

Commit 75ae54a8, authored Mar 29, 2024 by yoach@huggingface.co
Parent: 5e2041eb

improve pre-processing logics

Showing 1 changed file with 24 additions and 18 deletions.

run_stable_speech_training.py (+24, -18) @ 75ae54a8
@@ -1039,7 +1039,14 @@ def main():
     # Freeze Encoders
     model.freeze_encoders(model_args.freeze_text_encoder)
 
+    # TODO: remove when releasing
+    # Test all gather - used for warmout and avoiding timeout
+    test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
+    gathered_tensor = accelerator.gather(test_tensor)
+    print("gathered_tensor", gathered_tensor)
+    accelerator.wait_for_everyone()
+
     if not dataset_was_precomputed:
         # Filter on text length
         if description_column_name is not None:
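The added lines are a temporary sanity check (flagged by the TODO): every process gathers a one-element tensor holding its own rank, which exercises the collective-communication backend once before the long preprocessing phase and surfaces any hang early instead of during training. A minimal standalone sketch of the same pattern, assuming a script launched with Hugging Face Accelerate (e.g. via accelerate launch); the Accelerator instance here is created locally for the sketch, whereas the training script already has one:

import torch
from accelerate import Accelerator

accelerator = Accelerator()

# Each process contributes a one-element tensor containing its own rank.
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)

# gather() concatenates the per-process tensors, so with N processes every
# process ends up with [0, 1, ..., N-1]; a hang here points at the backend.
gathered_tensor = accelerator.gather(test_tensor)
print("gathered_tensor", gathered_tensor)

# Block until all processes reach this point before moving on.
accelerator.wait_for_everyone()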
@@ -1158,7 +1165,6 @@ def main():
         data_loader = accelerator.prepare(data_loader)
 
         all_generated_labels = []
-        all_ratios = []
         all_lens = []
 
         for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
             generate_labels = apply_audio_decoder(batch)
@@ -1166,30 +1172,31 @@ def main():
             generate_labels = accelerator.gather_for_metrics(generate_labels)
 
             if accelerator.is_main_process:
-                all_generated_labels.extend(generate_labels["labels"].cpu())
-                all_ratios.extend(generate_labels["ratio"].cpu().squeeze())
-                all_lens.extend(generate_labels["len_audio"].cpu().squeeze())
+                lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
+                rat = generate_labels["ratio"].cpu().squeeze()
+                lens = generate_labels["len_audio"].cpu().squeeze()
+                lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]
+
+                all_generated_labels.extend(lab)
+                all_lens.extend(lens)
 
         # (1, codebooks, seq_len) where seq_len=1
         bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
 
         if accelerator.is_main_process:
-            tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "ratios": all_ratios, "target_length": all_lens})
-            tmp_labels.save_to_disk(data_args.temporary_save_to_disk, num_proc=data_args.preprocessing_num_workers)
+            tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
+            tmp_labels.save_to_disk(
+                os.path.join(data_args.temporary_save_to_disk, split),
+                num_proc=1 if split == "eval" else data_args.preprocessing_num_workers,
+            )
         accelerator.wait_for_everyone()
         del all_generated_labels
 
-        tmp_labels = datasets.load_from_disk(data_args.temporary_save_to_disk)
+        tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split))
         with accelerator.main_process_first():
             vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
 
-        def postprocess_dataset(labels, target_length, ratio):
+        def postprocess_dataset(labels):
             # (1, codebooks, seq_len)
-            labels = torch.tensor(labels).transpose(0, 1).unsqueeze(0)
-            len_ = int(ratio * target_length)
-            labels = labels[:, :, :len_]
+            labels = torch.tensor(labels).unsqueeze(0)
             # add bos
             labels = torch.cat([bos_labels, labels], dim=-1)
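The main change in this hunk is that the gathered codec labels are truncated to their true audio length on the main process before being written out, so the temporary dataset no longer needs to carry padded frames plus a "ratios" column into postprocess_dataset. A small self-contained sketch of that truncation step with dummy tensors (the shapes and values are illustrative, not taken from the script):

import torch

# Dummy gathered batch: 2 examples, 10 padded frames, 8 codebooks.
generate_labels = {
    "labels": torch.randint(0, 1024, (2, 10, 8)),   # (batch, seq_len, codebooks)
    "ratio": torch.tensor([[0.5], [0.8]]),          # fraction of frames that are real audio
    "len_audio": torch.tensor([[10.0], [10.0]]),    # nominal frame count per example
}

lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)  # (batch, codebooks, seq_len)
rat = generate_labels["ratio"].cpu().squeeze()
lens = generate_labels["len_audio"].cpu().squeeze()

# Keep only the non-padded frames of each example.
lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]

for l in lab:
    print(tuple(l.shape))   # (codebooks, true_frames) per example

Note also that the temporary labels are now written to and reloaded from a per-split subdirectory, os.path.join(data_args.temporary_save_to_disk, split), so the train and eval label caches no longer overwrite each other.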
@@ -1210,7 +1217,7 @@ def main():
             # the first timestamp is associated to a row full of BOS, let's get rid of it
             # we also remove the last timestampts (full of PAD)
-            output = {"labels": labels[:, 1:].cpu()}
+            output = {"labels": labels[:, 1:]}
             return output
@@ -1219,14 +1226,13 @@ def main():
         vectorized_datasets[split] = vectorized_datasets[split].map(
             postprocess_dataset,
             num_proc=data_args.preprocessing_num_workers,  # this one is resource consuming if many processor.
-            input_columns=["labels", "target_length", "ratios"],
-            remove_columns=["ratios"],
+            input_columns=["labels"],
             desc="Postprocessing labeling",
         )
 
         accelerator.free_memory()
-        del generate_labels, all_lens, all_ratios
+        del generate_labels, all_lens
 
     with accelerator.main_process_first():
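Because postprocess_dataset now only needs the already-truncated labels, the map() call passes a single input column and no longer has a "ratios" column to remove. A toy example of how datasets.Dataset.map with input_columns feeds just the named column to the mapped function (the dataset contents here are made up):

from datasets import Dataset

# Toy stand-in for vectorized_datasets[split]; values are hypothetical.
ds = Dataset.from_dict({"labels": [[1, 2, 3], [4, 5]], "target_length": [3, 2]})

def postprocess_dataset(labels):
    # With input_columns=["labels"], map() passes only that column's value.
    return {"labels": [0] + labels}   # e.g. prepend a BOS-like token id

ds = ds.map(
    postprocess_dataset,
    input_columns=["labels"],
    desc="Postprocessing labeling",
)
print(ds[0])   # {'labels': [0, 1, 2, 3], 'target_length': 3}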
@@ -1242,7 +1248,7 @@ def main():
     if data_args.save_to_disk is not None and not dataset_was_precomputed:
         if accelerator.is_main_process:
-            vectorized_datasets.save_to_disk(data_args.save_to_disk, num_proc=data_args.preprocessing_num_workers)
+            vectorized_datasets.save_to_disk(data_args.save_to_disk, num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1))
         logger.info(f"Dataset saved at {data_args.save_to_disk}")
 
     # for large datasets it is advised to run the preprocessing on a
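The final save now clamps num_proc to the size of the eval split minus one, presumably so that a small eval set is not sharded into more pieces than it has rows when preprocessing_num_workers is large. A sketch of the clamped call on a toy DatasetDict (the split sizes and worker count are invented for illustration):

import tempfile
from datasets import Dataset, DatasetDict

vectorized_datasets = DatasetDict(
    {
        "train": Dataset.from_dict({"labels": list(range(100))}),
        "eval": Dataset.from_dict({"labels": list(range(8))}),
    }
)

preprocessing_num_workers = 32  # hypothetical value of data_args.preprocessing_num_workers

with tempfile.TemporaryDirectory() as save_to_disk:
    # Cap the number of writer processes by the smallest split so that no
    # split is asked to produce more shards than it has examples.
    vectorized_datasets.save_to_disk(
        save_to_disk,
        num_proc=min(preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
    )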