chenpangpang / parler-tts · Commit b09eba24, authored Mar 26, 2024 by yoach@huggingface.co
compute audio in collator instead of previously
Parent: 441af9a4
Showing 1 changed file with 22 additions and 14 deletions.
run_stable_speech_training.py @ b09eba24
@@ -488,16 +488,18 @@ class DataCollatorEncodecWithPadding:
     """

     feature_extractor: AutoFeatureExtractor
+    audio_column_name: str
     feature_extractor_input_name: Optional[str] = "input_values"

     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
         # different padding methods
-        audios = [torch.tensor(feature["labels"]).squeeze() for feature in features]
+        audios = [torch.tensor(feature[self.audio_column_name]).squeeze() for feature in features]
         len_audio = [len(audio) for audio in audios]

-        input_features = {self.feature_extractor_input_name: audios}
-        batch = self.feature_extractor.pad(input_features, return_tensors="pt", padding="longest", return_attention_mask=True)
+        batch = self.feature_extractor(audios, return_tensors="pt", padding="longest", return_attention_mask=True)
+        batch[self.feature_extractor_input_name] = batch[self.feature_extractor_input_name].unsqueeze(1)  # add mono-channel
         batch["padding_mask"] = batch.pop("attention_mask")
         batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
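
Note: after this change the collator reads raw waveforms straight from `audio_column_name` and runs the feature extractor at batching time, instead of consuming precomputed "labels". A minimal sketch of that behaviour follows; `Wav2Vec2FeatureExtractor` is only a stand-in for whatever extractor the script loads via `AutoFeatureExtractor`, and the two random waveforms and the 16 kHz rate are made up for illustration.

# Sketch only: Wav2Vec2FeatureExtractor stands in for the script's extractor; waveforms are random.
import numpy as np
import torch
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16_000, padding_value=0.0)

# two mono clips of different lengths, as they would come out of the dataset's audio column
audios = [np.random.randn(16_000).astype(np.float32), np.random.randn(8_000).astype(np.float32)]
len_audio = [len(audio) for audio in audios]

# pad raw audio to the longest clip at collate time
batch = feature_extractor(
    audios, sampling_rate=16_000, return_tensors="pt", padding="longest", return_attention_mask=True
)
batch["input_values"] = batch["input_values"].unsqueeze(1)  # add a mono-channel dimension
batch["padding_mask"] = batch.pop("attention_mask")         # rename, as the collator does
batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)   # keep the original length of each clip

print(batch["input_values"].shape)  # torch.Size([2, 1, 16000])
print(batch["padding_mask"].shape)  # torch.Size([2, 16000])
print(batch["len_audio"].tolist())  # [[16000], [8000]]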
@@ -1032,7 +1034,7 @@ def main():
     # Freeze Encoders
     model.freeze_encoders(model_args.freeze_text_encoder)

-    # TODO: remove
+    # TODO: remove when releasing
     # Test all gather - used for warmout and avoiding timeout
     test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
     gathered_tensor = accelerator.gather(test_tensor)
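
The all-gather warm-up touched here can be reproduced on its own: under `accelerate launch`, each rank contributes its process index and the gathered tensor should enumerate every rank. A self-contained sketch (the two-process output in the comment is illustrative):

# Minimal reproduction of the all-gather sanity check; run with `accelerate launch <file>.py`.
import torch
from accelerate import Accelerator

accelerator = Accelerator()

test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
gathered_tensor = accelerator.gather(test_tensor)

# on 2 processes this prints tensor([0, 1]); on a single process, tensor([0])
accelerator.print(f"gathered: {gathered_tensor}")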
@@ -1062,24 +1064,30 @@ def main():
         text = batch[prompt_column_name]
         batch["prompt_input_ids"] = prompt_tokenizer(text.strip())["input_ids"]

-        # load audio
-        target_sample = batch[target_audio_column_name]
-        arr = target_sample["array"]
-        labels = feature_extractor(arr[: min(len(arr), max_target_length + 10)], sampling_rate=target_sample["sampling_rate"])
-        batch["labels"] = labels["input_values"]
         # take length of raw audio waveform
-        batch["target_length"] = len(target_sample["array"].squeeze())
+        batch["target_length"] = len(batch[target_audio_column_name]["array"].squeeze())
         return batch

     with accelerator.main_process_first():
-        vectorized_datasets = raw_datasets.map(
+        # this is a trick to avoid to rewrite the entire audio column which takes ages
+        tmp_datasets = raw_datasets.map(
             pass_through_processors,
             remove_columns=next(iter(raw_datasets.values())).column_names,
             num_proc=num_workers,
             desc="preprocess datasets",
             # cache_file_names={"train": "/scratch/train.arrow", "eval":"/scratch/eval.arrow"} , # TODO: remove - specific to cluster
         )

+        # only keep audio column from the raw datasets
+        # this is a trick to avoid to rewrite the entire audio column which takes ages
+        cols_to_remove = [col for col in next(iter(raw_datasets.values())).column_names if col != target_audio_column_name]
+        for split in raw_datasets:
+            vectorized_datasets[split] = concatenate_datasets([raw_datasets[split].remove_columns(cols_to_remove), tmp_datasets[split]], axis=1)
+
+        # TODO: remove
+        logger.info(f"Vectorized datasets {vectorized_datasets}")

     with accelerator.main_process_first():

         def is_audio_in_length_range(length):
             return length > min_target_length and length < max_target_length
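
The rewritten preprocessing avoids rewriting the heavy audio column: `map` drops every original column so only the small new columns are materialised, and the untouched audio column is then re-attached column-wise with `concatenate_datasets(..., axis=1)`. Below is a toy sketch of the same trick; the column names and data are made up, only the pattern matches the hunk above.

# Toy sketch of the "keep the audio column out of map" trick (made-up columns and values).
from datasets import Dataset, DatasetDict, concatenate_datasets

# "audio" plays the role of target_audio_column_name, "text" of the prompt column
raw_datasets = DatasetDict({
    "train": Dataset.from_dict({
        "audio": [[0.0, 0.1, 0.2], [0.3, 0.4]],   # stands in for large waveform arrays
        "text": ["hello there", "world"],
    })
})

def pass_through(example):
    # lightweight per-row processing; the audio column itself is never rewritten
    return {"prompt_length": len(example["text"]), "target_length": len(example["audio"])}

# map drops every original column, so only the new (small) columns are written to the cache
tmp_datasets = raw_datasets.map(pass_through, remove_columns=next(iter(raw_datasets.values())).column_names)

# re-attach the untouched audio column, column-wise, without copying it through map
cols_to_remove = [col for col in next(iter(raw_datasets.values())).column_names if col != "audio"]
vectorized_datasets = DatasetDict()
for split in raw_datasets:
    vectorized_datasets[split] = concatenate_datasets(
        [raw_datasets[split].remove_columns(cols_to_remove), tmp_datasets[split]], axis=1
    )

print(vectorized_datasets["train"].column_names)  # ['audio', 'prompt_length', 'target_length']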
@@ -1150,7 +1158,7 @@ def main():
     # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
     audio_decoder = model.audio_encoder

-    encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, feature_extractor_input_name)
+    encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name)

     def apply_audio_decoder(batch):
         len_audio = batch.pop("len_audio")
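
Because `audio_column_name` is inserted between the two existing dataclass fields, the old positional call `DataCollatorEncodecWithPadding(feature_extractor, feature_extractor_input_name)` would now bind the input name to `audio_column_name`, hence the switch to keyword arguments. A small illustration of the field order (the "audio" column name is an assumption, not taken from the script):

# Field order of the collator after this commit; keyword arguments keep each value on the
# intended field, which the old two-argument positional call no longer would.
from dataclasses import dataclass, fields
from typing import Optional

@dataclass
class DataCollatorEncodecWithPadding:
    feature_extractor: object                # AutoFeatureExtractor in the training script
    audio_column_name: str                   # field added by this commit, in second position
    feature_extractor_input_name: Optional[str] = "input_values"

print([f.name for f in fields(DataCollatorEncodecWithPadding)])
# ['feature_extractor', 'audio_column_name', 'feature_extractor_input_name']

collator = DataCollatorEncodecWithPadding(
    feature_extractor=None,                  # stand-in; the script passes its loaded extractor
    audio_column_name="audio",               # assumption: the script passes target_audio_column_name
    feature_extractor_input_name="input_values",
)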