chenpangpang / parler-tts

Commit 5e2041eb, authored Mar 26, 2024 by yoach@huggingface.co

    make smarter audio encoding in terms of RAM usage

Parent: 84e0def5
Showing 1 changed file with 67 additions and 70 deletions

run_stable_speech_training.py (+67 -70)
@@ -444,6 +444,12 @@ class DataTrainingArguments:
             "help": "If set, will save the dataset to this path if this is an empyt folder. If not empty, will load the datasets from it."
         }
     )
+    temporary_save_to_disk: str = field(
+        default=None,
+        metadata={
+            "help": "Temporarily save audio labels here."
+        }
+    )
     pad_to_multiple_of: Optional[int] = field(
         default=2,
         metadata={
@@ -490,6 +496,7 @@ class DataCollatorEncodecWithPadding:
     feature_extractor: AutoFeatureExtractor
     audio_column_name: str
     feature_extractor_input_name: Optional[str] = "input_values"
+    max_length: Optional[int] = None

     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
@@ -497,6 +504,8 @@ class DataCollatorEncodecWithPadding:
         # different padding methods
         audios = [feature[self.audio_column_name]["array"] for feature in features]
         len_audio = [len(audio) for audio in audios]
+        if self.max_length is not None:
+            audios = [audio[: min(len(audio), self.max_length + 10)] for audio in audios]
         batch = self.feature_extractor(audios, return_tensors="pt", padding="longest")
         batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
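Note: with padding="longest", a single long outlier clip sets the padded width for the whole batch, so truncating to max_length before feature extraction bounds batch RAM. A minimal back-of-the-envelope sketch (hypothetical numbers, not from the commit):

def padded_batch_elems(lengths, max_length=None):
    # mimic the collator: optionally truncate, then pad every clip to the longest one
    if max_length is not None:
        lengths = [min(length, max_length) for length in lengths]
    return len(lengths) * max(lengths)

lengths = [16_000, 24_000, 1_600_000]        # one 100 s outlier at 16 kHz
print(padded_batch_elems(lengths))           # 4_800_000 padded samples
print(padded_batch_elems(lengths, 480_000))  # 1_440_000 with a 30 s cap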
@@ -1030,14 +1039,7 @@ def main():
     # Freeze Encoders
     model.freeze_encoders(model_args.freeze_text_encoder)

-    # TODO: remove when releasing
-    # Test all gather - used for warmout and avoiding timeout
-    test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
-    gathered_tensor = accelerator.gather(test_tensor)
-    print("gathered_tensor", gathered_tensor)
-    accelerator.wait_for_everyone()
-
     if not dataset_was_precomputed:
         # Filter on text length
         if description_column_name is not None:
@@ -1049,53 +1051,28 @@ def main():
                 input_columns=[description_column_name],
             )

-        # Preprocessing the datasets.
-        # We need to read the audio files as arrays and tokenize the texts.
-        def pass_through_processors(batch):
-            # load audio
-            if description_column_name is not None:
-                text = batch[description_column_name]
-                batch["input_ids"] = description_tokenizer(text.strip())["input_ids"]
-
-            if prompt_column_name is not None:
-                text = batch[prompt_column_name]
-                batch["prompt_input_ids"] = prompt_tokenizer(text.strip())["input_ids"]
-
-            # take length of raw audio waveform
-            batch["target_length"] = len(batch[target_audio_column_name]["array"].squeeze())
+        # Preprocessing the dataset.
+        # We need to tokenize the texts.
+        def pass_through_processors(description, prompt):
+            batch = {}
+
+            batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
+            # TODO: add possibility to train without description column
+            batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]

             return batch

         with accelerator.main_process_first():
             # this is a trick to avoid to rewrite the entire audio column which takes ages
-            tmp_datasets = raw_datasets.map(
+            vectorized_datasets = raw_datasets.map(
                 pass_through_processors,
                 remove_columns=next(iter(raw_datasets.values())).column_names,
+                input_columns=[description_column_name, prompt_column_name],
                 num_proc=num_workers,
                 desc="preprocess datasets",
-                # cache_file_names={"train": "/scratch/train.arrow", "eval":"/scratch/eval.arrow"} , # TODO: remove - specific to cluster
-            )
-
-            # only keep audio column from the raw datasets
-            # this is a trick to avoid to rewrite the entire audio column which takes ages
-            cols_to_remove = [col for col in next(iter(raw_datasets.values())).column_names if col != target_audio_column_name]
-            for split in raw_datasets:
-                raw_datasets[split] = concatenate_datasets([raw_datasets[split].remove_columns(cols_to_remove), tmp_datasets[split]], axis=1)
-
-        with accelerator.main_process_first():
-            def is_audio_in_length_range(length):
-                return length > min_target_length and length < max_target_length
-
-            # filter data that is shorter than min_target_length
-            vectorized_datasets = raw_datasets.filter(
-                is_audio_in_length_range,
-                num_proc=num_workers,
-                input_columns=["target_length"],
-            )
             )

     # We use Accelerate to perform distributed inference
     # T5 doesn't support fp16
     autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
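Note: the map now hands only the two text columns to pass_through_processors via input_columns, and with remove_columns the audio arrays are never copied into the new dataset; the audio stays in raw_datasets. A self-contained sketch of that pattern (toy column names and a stand-in tokenizer, assuming the Hugging Face datasets library):

from datasets import Dataset

ds = Dataset.from_dict({
    "description": ["a calm voice", "a fast speaker"],
    "prompt": ["hello there", "good morning"],
    "audio": [[0.0] * 16_000, [0.1] * 16_000],  # stand-in for large waveform arrays
})

def tokenize_texts(description, prompt):
    # only the two requested text columns are passed in; the audio column is untouched
    return {"input_ids": description.split(), "prompt_input_ids": prompt.split()}

vectorized = ds.map(
    tokenize_texts,
    input_columns=["description", "prompt"],
    remove_columns=ds.column_names,  # keep only the newly produced columns
)
print(vectorized.column_names)  # ['input_ids', 'prompt_input_ids']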
@@ -1118,13 +1095,15 @@ def main():
             for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
                 model.text_encoder.to(batch["input_ids"].device)
                 with accelerator.autocast(autocast_handler=autocast_kwargs):
-                    encoder_outputs = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
+                    with torch.no_grad():
+                        encoder_outputs = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
                 encoder_outputs = accelerator.pad_across_processes(encoder_outputs, dim=1, pad_index=prompt_tokenizer.pad_token_id)
                 encoder_outputs = accelerator.gather_for_metrics(encoder_outputs)
                 lengths = accelerator.gather_for_metrics(batch["len_input_ids"])
-                all_encoder_outputs.extend(encoder_outputs.last_hidden_state.to("cpu"))
-                all_encoder_lengths.extend(lengths.to("cpu"))
+                if accelerator.is_main_process:
+                    all_encoder_outputs.extend(encoder_outputs.last_hidden_state.to("cpu"))
+                    all_encoder_lengths.extend(lengths.to("cpu"))

             def postprocess_dataset(input_ids, idx):
                 output = {"encoder_outputs": BaseModelOutput(last_hidden_state=all_encoder_outputs[idx][: all_encoder_lengths[idx]])}
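Note: two memory-conscious patterns appear in this hunk: the text encoder runs under torch.no_grad(), so no autograd graph is retained for a pure-inference pass, and the gathered hidden states are accumulated on the main process only, since after gather_for_metrics every rank holds an identical copy. A plain-PyTorch sketch of the first point (toy model, no Accelerate):

import torch

model = torch.nn.Linear(512, 512)
x = torch.randn(8, 512)

with torch.no_grad():
    y = model(x)  # pure inference: no activations kept for a backward pass

assert not y.requires_grad  # nothing to backpropagate, no graph held in memory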
@@ -1140,6 +1119,7 @@ def main():
                 with_indices=True,
                 writer_batch_size=100,
             )
+        accelerator.wait_for_everyone()
         accelerator.free_memory()
         del data_loader, all_encoder_outputs, all_encoder_lengths
@@ -1153,7 +1133,7 @@ def main():
         # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
         audio_decoder = model.audio_encoder

-        encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name)
+        encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, audio_column_name=target_audio_column_name, feature_extractor_input_name=feature_extractor_input_name, max_length=max_target_length)

         def apply_audio_decoder(batch):
             len_audio = batch.pop("len_audio")
@@ -1169,7 +1149,7 @@ def main():
         for split in vectorized_datasets:
             data_loader = DataLoader(
-                vectorized_datasets[split],
+                raw_datasets[split],
                 batch_size=training_args.audio_encode_per_device_eval_batch_size,
                 collate_fn=encoder_data_collator,
                 num_workers=training_args.dataloader_num_workers,
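Note: the encoding loader now iterates raw_datasets[split], which still carries the audio column, instead of the text-only vectorized_datasets[split]; the collator only reads the audio column anyway. A minimal sketch of a DataLoader over a Hugging Face dataset with a custom collator (toy rows shaped like row[audio_column]["array"], hypothetical data):

from datasets import Dataset
from torch.utils.data import DataLoader

audio_ds = Dataset.from_dict({"audio": [{"array": [0.0] * 800}, {"array": [0.1] * 1600}]})

def collate(features):
    # same interface as DataCollatorEncodecWithPadding: list of rows in, one batch dict out
    audios = [feature["audio"]["array"] for feature in features]
    return {"len_audio": [len(audio) for audio in audios]}

loader = DataLoader(audio_ds, batch_size=2, collate_fn=collate)
print(next(iter(loader)))  # {'len_audio': [800, 1600]}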
@@ -1185,22 +1165,31 @@ def main():
                 generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
                 generate_labels = accelerator.gather_for_metrics(generate_labels)

-                all_generated_labels.extend(generate_labels["labels"].cpu())
-                all_ratios.extend(generate_labels["ratio"].cpu())
-                all_lens.extend(generate_labels["len_audio"].cpu())
+                if accelerator.is_main_process:
+                    all_generated_labels.extend(generate_labels["labels"].cpu())
+                    all_ratios.extend(generate_labels["ratio"].cpu().squeeze())
+                    all_lens.extend(generate_labels["len_audio"].cpu().squeeze())

             # (1, codebooks, seq_len) where seq_len=1
-            eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id
             bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id

-            def postprocess_dataset(input_ids, prompt_input_ids, idx):
+            if accelerator.is_main_process:
+                tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "ratios": all_ratios, "target_length": all_lens})
+                tmp_labels.save_to_disk(data_args.temporary_save_to_disk, num_proc=data_args.preprocessing_num_workers)
+            accelerator.wait_for_everyone()
+            del all_generated_labels
+
+            tmp_labels = datasets.load_from_disk(data_args.temporary_save_to_disk)
+            with accelerator.main_process_first():
+                vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
+
+            def postprocess_dataset(labels, target_length, ratio):
                 # (1, codebooks, seq_len)
-                labels = all_generated_labels[idx].transpose(0, 1).unsqueeze(0)
-                len_ = int(all_ratios[idx] * all_lens[idx])
+                labels = torch.tensor(labels).transpose(0, 1).unsqueeze(0)
+                len_ = int(ratio * target_length)
                 labels = labels[:, :, :len_]
-                # labels = labels[:, :, :(len_)%10+500] # TODO: change

                 # add bos
                 labels = torch.cat([bos_labels, labels], dim=-1)
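Note: this is the heart of the RAM saving named in the commit message: instead of keeping every generated label tensor in a Python list on all ranks, the main process writes them to disk as an Arrow dataset, the in-memory list is deleted, and all ranks reload the result Arrow-backed. A minimal sketch of the round trip (toy labels and a temporary directory, assuming the Hugging Face datasets library):

import tempfile

from datasets import Dataset, load_from_disk

all_generated_labels = [[1, 2, 3], [4, 5]]  # stand-in for encoded audio codes
tmp_dir = tempfile.mkdtemp()

Dataset.from_dict({"labels": all_generated_labels}).save_to_disk(tmp_dir)
del all_generated_labels  # drop the RAM-resident copy

tmp_labels = load_from_disk(tmp_dir)  # Arrow-backed, read from disk on access
print(tmp_labels[0]["labels"])  # [1, 2, 3]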
@@ -1210,7 +1199,6 @@ def main():
                     max_length=labels.shape[-1] + num_codebooks,
                     num_codebooks=num_codebooks)
-
                 # the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
                 # to take care of EOS
                 # we want labels to look like this:
@@ -1223,29 +1211,38 @@ def main():
                 # the first timestamp is associated to a row full of BOS, let's get rid of it
                 # we also remove the last timestampts (full of PAD)
                 output = {"labels": labels[:, 1:].cpu()}
-                output["input_ids"] = input_ids
-                output["prompt_input_ids"] = prompt_input_ids
                 return output

             # TODO(YL): done multiple times, how to deal with it.
             with accelerator.main_process_first():
                 vectorized_datasets[split] = vectorized_datasets[split].map(
                     postprocess_dataset,
-                    num_proc=1,  # this one is resource consuming if many processor.
-                    input_columns=["input_ids", "prompt_input_ids"],
+                    num_proc=data_args.preprocessing_num_workers,  # this one is resource consuming if many processor.
+                    input_columns=["labels", "target_length", "ratios"],
+                    remove_columns=["ratios"],
                     desc="Postprocessing labeling",
-                    with_indices=True,
-                    writer_batch_size=100,
                 )

         accelerator.free_memory()
-        del generate_labels, all_generated_labels, all_lens, all_ratios
+        del generate_labels, all_lens, all_ratios
+
+        with accelerator.main_process_first():
+            def is_audio_in_length_range(length):
+                return length > min_target_length and length < max_target_length
+
+            # filter data that is shorter than min_target_length
+            vectorized_datasets = vectorized_datasets.filter(
+                is_audio_in_length_range,
+                num_proc=num_workers,
+                input_columns=["target_length"],
+            )

     if data_args.save_to_disk is not None and not dataset_was_precomputed:
-        vectorized_datasets.save_to_disk(data_args.save_to_disk)
+        if accelerator.is_main_process:
+            vectorized_datasets.save_to_disk(data_args.save_to_disk, num_proc=data_args.preprocessing_num_workers)
         logger.info(f"Dataset saved at {data_args.save_to_disk}")

     # for large datasets it is advised to run the preprocessing on a
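Note: the length filter moves here, after encoding, because target_length is now derived from the encoded labels rather than from the raw waveform. A toy sketch of datasets.filter with input_columns (hypothetical bounds and data):

from datasets import Dataset

min_target_length, max_target_length = 10, 100  # hypothetical bounds

ds = Dataset.from_dict({"labels": [[1], [2], [3]], "target_length": [5, 50, 500]})

def is_audio_in_length_range(length):
    # only the target_length column is passed in, one value per example
    return length > min_target_length and length < max_target_length

filtered = ds.filter(is_audio_in_length_range, input_columns=["target_length"])
print(filtered["target_length"])  # [50]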