chenpangpang / parler-tts · Commit 997bf5e6

make audio encoding multi-gpus compatible

Authored Feb 23, 2024 by Yoach Lacombe
Parent: e25b8ba0

Showing 1 changed file with 192 additions and 144 deletions.

run_stable_speech_training.py (+192, −144)
```diff
@@ -32,6 +32,8 @@ from typing import Dict, List, Optional, Union
 import datasets
 import numpy as np
 import torch
+from torch.utils.data import DataLoader
 from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
 import transformers
```
```diff
@@ -43,13 +45,15 @@ from transformers import (
     HfArgumentParser,
     Seq2SeqTrainer,
     Seq2SeqTrainingArguments,
     set_seed,
 )
 from transformers.trainer_utils import get_last_checkpoint, is_main_process
 from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version
 from transformers.integrations import is_wandb_available
-from accelerate import PartialState
+from accelerate import Accelerator
+from accelerate.utils import set_seed
 from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig
```
```diff
@@ -214,10 +218,6 @@ class DataSeq2SeqTrainingArguments:
         default="audio",
         metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
     )
-    conditional_audio_column_name: str = field(  # TODO
-        default=None,
-        metadata={"help": "The name of the dataset column containing the conditional audio data. Defaults to 'audio'"},
-    )
     description_column_name: str = field(  # TODO
         default=None,
         metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
```
```diff
@@ -311,6 +311,29 @@ class DataSeq2SeqTrainingArguments:
             "help": "id column name."
         }
     )


+@dataclass
+class DataCollatorEncodecWithPadding:
+    """
+    Data collator that will dynamically pad the inputs received to the longest sequence in the batch.
+    """
+
+    feature_extractor: AutoFeatureExtractor
+    feature_extractor_input_name: Optional[str] = "input_values"
+
+    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        # split inputs and labels since they have to be of different lengths and need
+        # different padding methods
+        audios = [torch.tensor(feature["labels"]).squeeze() for feature in features]
+        len_audio = [len(audio) for audio in audios]
+        max_audio = max(len_audio)
+
+        input_features = {self.feature_extractor_input_name: audios}
+        batch = self.feature_extractor.pad(input_features, return_tensors="pt", padding="longest", return_attention_mask=True)
+        batch[self.feature_extractor_input_name] = batch[self.feature_extractor_input_name].unsqueeze(1)  # add mono-channel
+        batch["padding_mask"] = batch.pop("attention_mask")
+        batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
+
+        return batch
+
+
 @dataclass
```
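A note on how this new collator is meant to be driven: it batches raw waveforms stored under `"labels"`, pads them to the longest sample, and keeps the true lengths around so the encoded codes can be trimmed later. A minimal usage sketch, assuming an EnCodec-style feature extractor; the checkpoint name and toy waveforms are illustrative, not part of the commit:

```python
# Usage sketch only; the checkpoint name and toy waveforms are assumptions.
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_32khz")
collator = DataCollatorEncodecWithPadding(feature_extractor, "input_values")

# each feature carries a raw waveform under "labels"
features = [{"labels": [0.0] * 16000}, {"labels": [0.0] * 8000}]
batch = collator(features)
# batch["input_values"] -> shape (2, 1, 16000): padded to the longest sample, mono channel added
# batch["padding_mask"] -> marks real samples vs padding
# batch["len_audio"]    -> tensor([[16000], [8000]]): original lengths, kept for later trimming
```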
```diff
@@ -437,6 +460,7 @@ def convert_dataset_str_to_list(
 def load_multiple_datasets(
+    accelerator: Accelerator,
     dataset_names: Union[List, str],
     dataset_config_names: Union[List, str],
     metadata_dataset_names: Optional[str] = None,
```
```diff
@@ -463,51 +487,52 @@ def load_multiple_datasets(
     all_datasets = []
     # iterate over the datasets we want to interleave
     for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
-        dataset = load_dataset(
-            dataset_dict["name"],
-            dataset_dict["config"],
-            split=dataset_dict["split"],
-            streaming=streaming,
-            **kwargs,
-        )
-        dataset_features = dataset.features.keys()
-
-        metadata_dataset_name = dataset_dict["metadata_dataset_name"]
-        if metadata_dataset_name is not None:
-            metadata_dataset = load_dataset(
-                metadata_dataset_name,
-                dataset_dict["config"],
-                split=dataset_dict["split"],
-                streaming=streaming,
-                **kwargs,
-            )
-
-            if id_column_name is not None and id_column_name not in dataset.column_names:
-                raise ValueError(
-                    f"id_column_name={id_column_name} but has not been found in the dataset columns"
-                    f"- one of {', '.join(list(dataset.column_names))}."
-                )
-            if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
-                raise ValueError(
-                    f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
-                    f"- one of {', '.join(list(metadata_dataset.column_names))}."
-                )
-            elif id_column_name is not None:
-                metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
-
-            metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
-            metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
-            dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
-
-            if id_column_name is not None:
-                if len(dataset.filter(lambda id1, id2: id1 != id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
-                    raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
-
-            dataset_features = dataset.features.keys()
-
-        if columns_to_keep is not None:
-            dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
+        with accelerator.main_process_first():
+            dataset = load_dataset(
+                dataset_dict["name"],
+                dataset_dict["config"],
+                split=dataset_dict["split"],
+                streaming=streaming,
+                **kwargs,
+            )
+            dataset_features = dataset.features.keys()
+
+            metadata_dataset_name = dataset_dict["metadata_dataset_name"]
+            if metadata_dataset_name is not None:
+                metadata_dataset = load_dataset(
+                    metadata_dataset_name,
+                    dataset_dict["config"],
+                    split=dataset_dict["split"],
+                    streaming=streaming,
+                    **kwargs,
+                )
+
+                if id_column_name is not None and id_column_name not in dataset.column_names:
+                    raise ValueError(
+                        f"id_column_name={id_column_name} but has not been found in the dataset columns"
+                        f"- one of {', '.join(list(dataset.column_names))}."
+                    )
+                if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
+                    raise ValueError(
+                        f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
+                        f"- one of {', '.join(list(metadata_dataset.column_names))}."
+                    )
+                elif id_column_name is not None:
+                    metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
+
+                metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
+                metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
+                dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
+
+                if id_column_name is not None:
+                    if len(dataset.filter(lambda id1, id2: id1 != id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
+                        raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
+
+                dataset_features = dataset.features.keys()
+
+            if columns_to_keep is not None:
+                dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
         all_datasets.append(dataset)

     if len(all_datasets) == 1:
```
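The metadata join above relies on `concatenate_datasets(..., axis=1)`, which matches rows purely by position; that is why the loop double-checks that `id_column_name` and its renamed metadata twin agree. A self-contained sketch of the same join on toy data:

```python
# Toy illustration of the positional, column-wise join used above.
from datasets import Dataset, concatenate_datasets

dataset = Dataset.from_dict({"id": ["a", "b"], "audio_path": ["a.wav", "b.wav"]})
metadata = Dataset.from_dict({"metadata_id": ["a", "b"], "text": ["hello", "world"]})

joined = concatenate_datasets([dataset, metadata], axis=1)
# rows were matched by position only, so verify the ids line up
assert joined.filter(lambda id1, id2: id1 != id2, input_columns=["id", "metadata_id"]).num_rows == 0
```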
```diff
@@ -522,7 +547,8 @@ def load_multiple_datasets(
             seed=seed,
         )
     else:
-        interleaved_dataset = concatenate_datasets(all_datasets)
+        with accelerator.main_process_first():
+            interleaved_dataset = concatenate_datasets(all_datasets)

     return interleaved_dataset
```
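`accelerator.main_process_first()` is the synchronization primitive this commit wraps around every cache-producing `datasets` call. A minimal sketch of its behaviour under a multi-process launch:

```python
# Sketch: rank 0 enters the block first and populates the datasets cache;
# the other ranks wait at the context manager, then run the same code and
# hit the warm cache instead of re-downloading or re-mapping.
from accelerate import Accelerator

accelerator = Accelerator()
with accelerator.main_process_first():
    pass  # e.g. load_dataset(...), dataset.map(...), dataset.filter(...)
```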
```diff
@@ -544,6 +570,8 @@ def main():
     # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
     # information sent is the one passed as arguments along with your Python/PyTorch versions.
     send_example_telemetry("run_stable_speech", model_args, data_args)

+    accelerator = Accelerator()
+
     # Detecting last checkpoint.
     last_checkpoint = None
```
```diff
@@ -566,7 +594,7 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         handlers=[logging.StreamHandler(sys.stdout)],
     )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)

     # Log on each process the small summary:
     logger.warning(
```
```diff
@@ -574,8 +602,9 @@ def main():
         f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
     )
     # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if accelerator.is_main_process:
         transformers.utils.logging.set_verbosity_info()
     logger.info("Training/evaluation parameters %s", training_args)

     # Set seed before initializing model.
```
```diff
@@ -585,71 +614,51 @@ def main():
     raw_datasets = DatasetDict()
     num_workers = data_args.preprocessing_num_workers

-    columns_to_keep = [data_args.target_audio_column_name, data_args.prompt_column_name]
+    columns_to_keep = {
+        "target_audio_column_name": data_args.target_audio_column_name,
+        "prompt_column_name": data_args.prompt_column_name,
+    }
     if data_args.description_column_name is not None:
-        columns_to_keep.append(data_args.description_column_name)
-    if data_args.conditional_audio_column_name is not None:
-        columns_to_keep.append(data_args.conditional_audio_column_name)
+        columns_to_keep["description_column_nam"] = data_args.description_column_name

     if training_args.do_train:
         raw_datasets["train"] = load_multiple_datasets(
+            accelerator,
             data_args.train_dataset_name,
             data_args.train_dataset_config_name,
-            data_args.train_metadata_dataset_name,
+            metadata_dataset_names=data_args.train_metadata_dataset_name,
             splits=data_args.train_split_name,
             dataset_samples=data_args.train_dataset_samples,
             seed=training_args.seed,
             cache_dir=model_args.cache_dir,
             num_proc=data_args.preprocessing_num_workers,
             id_column_name=data_args.id_column_name,
-            columns_to_keep=columns_to_keep,
+            columns_to_keep=columns_to_keep.values(),
             # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
         )

-        if data_args.target_audio_column_name not in raw_datasets["train"].column_names:
-            raise ValueError(
-                f"--target_audio_column_name '{data_args.target_audio_column_name}' not found in dataset '{data_args.train_dataset_name}'."
-                " Make sure to set `--target_audio_column_name` to the correct audio column - one of"
-                f" {', '.join(raw_datasets['train'].column_names)}."
-            )
-
-        if data_args.description_column_name is not None and data_args.description_column_name not in raw_datasets["train"].column_names:
-            raise ValueError(
-                f"--description_column_name {data_args.description_column_name} not found in dataset '{data_args.train_dataset_name}'. "
-                "Make sure to set `--description_column_name` to the correct text column - one of "
-                f"{', '.join(raw_datasets['train'].column_names)}."
-            )
-
-        if data_args.prompt_column_name not in raw_datasets["train"].column_names:
-            raise ValueError(
-                f"--description_column_name {data_args.prompt_column_name} not found in dataset '{data_args.train_dataset_name}'. "
-                "Make sure to set `--prompt_column_name` to the correct text column - one of "
-                f"{', '.join(raw_datasets['train'].column_names)}."
-            )
-
-        if data_args.conditional_audio_column_name is not None and data_args.conditional_audio_column_name not in raw_datasets["train"].column_names:
-            raise ValueError(
-                f"--conditional_audio_column_name {data_args.conditional_audio_column_name} not found in dataset '{data_args.train_dataset_name}'. "
-                "Make sure to set `--conditional_audio_column_name` to the correct text column - one of "
-                f"{', '.join(raw_datasets['train'].column_names)}."
-            )
+        for key in columns_to_keep:
+            if columns_to_keep[key] not in raw_datasets["train"].column_names:
+                raise ValueError(
+                    f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
+                    f" Make sure to set `--{key}` to the correct audio column - one of"
+                    f" {', '.join(raw_datasets['train'].column_names)}."
+                )

         if data_args.max_train_samples is not None:
             raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))

     if training_args.do_eval:
         raw_datasets["eval"] = load_multiple_datasets(
+            accelerator,
             data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
             data_args.eval_dataset_config_name if data_args.eval_dataset_config_name else data_args.train_dataset_config_name,
-            data_args.eval_metadata_dataset_name,
+            metadata_dataset_names=data_args.eval_metadata_dataset_name,
             splits=data_args.eval_split_name,
             cache_dir=model_args.cache_dir,
             num_proc=data_args.preprocessing_num_workers,
             id_column_name=data_args.id_column_name,
-            columns_to_keep=columns_to_keep,
+            columns_to_keep=columns_to_keep.values(),
             # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
         )
```
```diff
@@ -657,9 +666,8 @@ def main():
         raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))

-    # TODO: is is the right way to do ?
-    # 3. Next, let's load the config as we might need it to create
-    # load config
+    # 2. Next, let's load the config as we might need it to create
+    # load config TODO: add the option to create the config from scratch
     config = StableSpeechConfig.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
```
```diff
@@ -673,8 +681,7 @@ def main():
         "decoder_start_token_id": model_args.decoder_start_token_id
         if model_args.decoder_start_token_id is not None
         else model.config.decoder_start_token_id,
     })

-    # 4. Now we can instantiate the feature extractor, tokenizers and model
+    # 3. Now we can instantiate the feature extractor, tokenizers and model
     # Note for distributed training, the .from_pretrained methods guarantee that only
     # one local process can concurrently download model & vocab.
```
```diff
@@ -692,16 +699,24 @@ def main():
         cache_dir=model_args.cache_dir,
         token=data_args.token,
         trust_remote_code=data_args.trust_remote_code,
         use_fast=model_args.use_fast_tokenizer,
     )

+    # load description tokenizer
+    description_tokenizer = AutoTokenizer.from_pretrained(
+        model_args.description_tokenizer_name or model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        token=data_args.token,
+        trust_remote_code=data_args.trust_remote_code,
+        use_fast=model_args.use_fast_tokenizer,
+    )
+
     if model_args.use_fast_tokenizer:
         logger.warning(
             "Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
         )
         prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+        description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True

     # create model + TODO: not from_pretrained probably
     model = StableSpeechForConditionalGeneration.from_pretrained(
         model_args.model_name_or_path,
```
```diff
@@ -711,46 +726,31 @@ def main():
         trust_remote_code=data_args.trust_remote_code,
     )

-    # take audio_encoder_feature_extractor
-    audio_encoder_feature_extractor = AutoFeatureExtractor.from_pretrained(
-        model.config.audio_encoder._name_or_path,
-    )
-
-    # 5. Now we preprocess the datasets including loading the audio, resampling and normalization
+    # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
     # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
     # so that we just need to set the correct target sampling rate and normalize the input
     # via the `feature_extractor`

-    # resample target audio
-    raw_datasets = raw_datasets.cast_column(
-        data_args.target_audio_column_name, datasets.features.Audio(sampling_rate=audio_encoder_feature_extractor.sampling_rate)
-    )
-    if data_args.conditional_audio_column_name is not None:
-        raw_datasets = raw_datasets.cast_column(
-            data_args.conditional_audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
-        )
-
     # derive max & min input length for sample rate & max duration
-    max_target_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
-    min_target_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
+    sampling_rate = feature_extractor.sampling_rate
+    max_target_length = data_args.max_duration_in_seconds * sampling_rate
+    min_target_length = data_args.min_duration_in_seconds * sampling_rate
     target_audio_column_name = data_args.target_audio_column_name
-    conditional_audio_column_name = data_args.conditional_audio_column_name
     description_column_name = data_args.description_column_name
     prompt_column_name = data_args.prompt_column_name
     feature_extractor_input_name = feature_extractor.model_input_names[0]

+    # resample target audio
+    raw_datasets = raw_datasets.cast_column(
+        target_audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
+    )
+
     # Preprocessing the datasets.
-    # We need to read the audio files as arrays and tokenize the targets.
+    # We need to read the audio files as arrays and tokenize the texts.
     def pass_through_processors(batch):
-        # load audio
-        if conditional_audio_column_name is not None:
-            sample = batch[target_audio_column_name]
-            inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
-            batch[feature_extractor_input_name] = getattr(inputs, feature_extractor_input_name)[0]
-
         if description_column_name is not None:
             text = batch[description_column_name]
             batch["input_ids"] = description_tokenizer(text.strip())["input_ids"]
```
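Worth noting: `cast_column(..., Audio(sampling_rate=...))` does not eagerly resample anything; `datasets` decodes and resamples on access. A tiny sketch (the file name is a placeholder):

```python
# Lazy resampling sketch; "sample.wav" is a placeholder path.
from datasets import Dataset, Audio

ds = Dataset.from_dict({"audio": ["sample.wav"]}).cast_column("audio", Audio(sampling_rate=32_000))
# decoding + resampling to 32 kHz happens here, at access time:
# ds[0]["audio"] -> {"array": ..., "sampling_rate": 32000, "path": "sample.wav"}
```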
```diff
@@ -761,14 +761,14 @@ def main():
         # load audio
         target_sample = batch[target_audio_column_name]
-        labels = audio_encoder_feature_extractor(target_sample["array"], sampling_rate=target_sample["sampling_rate"])
+        labels = feature_extractor(target_sample["array"], sampling_rate=target_sample["sampling_rate"])
         batch["labels"] = labels["input_values"]

         # take length of raw audio waveform
         batch["target_length"] = len(target_sample["array"].squeeze())

         return batch

-    with training_args.main_process_first(desc="dataset map preprocessing"):
+    with accelerator.main_process_first():
         vectorized_datasets = raw_datasets.map(
             pass_through_processors,
             remove_columns=next(iter(raw_datasets.values())).column_names,
```
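The `map` call feeds every example through `pass_through_processors` and, via `remove_columns`, keeps only what the function returns. The same mechanics on toy data:

```python
# Toy illustration of Dataset.map with remove_columns.
from datasets import Dataset

ds = Dataset.from_dict({"text": ["hello", "world"]})
ds = ds.map(lambda example: {"n_chars": len(example["text"])}, remove_columns=ds.column_names)
print(ds.column_names)  # ['n_chars'] -- input columns dropped, mapped outputs kept
```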
```diff
@@ -785,34 +785,81 @@ def main():
             num_proc=num_workers,
             input_columns=["target_length"],
         )

+    # 5. Now we encode the audio labels with encodec.
+    # We use Accelerate to perform distributed inference
     logger.info("*** Encode target audio with encodec ***")

+    # no need to prepare audio_decoder because used for inference without mixed precision
+    # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
     # TODO: load another model
     audio_decoder = model.audio_encoder
+    encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, feature_extractor_input_name)

     def apply_audio_decoder(batch):
-        labels = audio_decoder.encode(torch.tensor(batch["labels"]).to(audio_decoder.device))["audio_codes"]
-        labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(
-            labels,
-            model.generation_config.decoder_start_token_id,
-            model.generation_config.max_length + 1,
-        )
-        labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
-
-        # the first timestamp is associated to a row full of BOS, let's get rid of it
-        batch["labels"] = labels[:, 1:]
-        return batch
-
-    with training_args.main_process_first(desc="audio target preprocessing"):
-        # for now on CPU
-        # TODO: enrich for GPU
-        vectorized_datasets = vectorized_datasets.map(
-            apply_audio_decoder,
-            num_proc=num_workers,
-            desc="preprocess datasets",
-        )
+        len_audio = batch.pop("len_audio")
+        audio_decoder.to(batch["input_values"].device).eval()
+        labels = audio_decoder.encode(**batch)["audio_codes"]
+
+        output = {}
+        output["len_audio"] = len_audio
+        # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
+        output["labels"] = labels.squeeze(0).transpose(1, 2)
+        output["ratio"] = torch.ones_like(len_audio) * labels.shape[-1] / len_audio.max()
+        return output
+
+    for split in vectorized_datasets:
+        data_loader = DataLoader(
+            vectorized_datasets[split],
+            batch_size=training_args.per_device_eval_batch_size,
+            collate_fn=encoder_data_collator,
+            num_workers=training_args.dataloader_num_workers,
+            pin_memory=True,
+        )
+        # TODO: will it work on GPU ? unmerged for now https://github.com/huggingface/accelerate/pull/2433
+        # for split in vectorized_datasets:
+        #     with distributed_state.split_between_processes(vectorized_datasets[split]["labels"]) as input_labels:
+        #         result = audio_decoder(input_labels)
+        data_loader = accelerator.prepare(data_loader)
+
+        all_generated_labels = []
+        all_ratios = []
+        all_lens = []
+        for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
+            generate_labels = apply_audio_decoder(batch)
+            generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
+            generate_labels = accelerator.gather_for_metrics(generate_labels)
+
+            all_generated_labels.extend(generate_labels["labels"].cpu())
+            all_ratios.extend(generate_labels["ratio"].cpu())
+            all_lens.extend(generate_labels["len_audio"].cpu())
+
+        def postprocess_dataset(sample, idx):
+            # (1, seq_len, codebooks, bsz)
+            labels = all_generated_labels[idx].transpose(0, 1).unsqueeze(0)
+            labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(
+                labels,
+                model.generation_config.decoder_start_token_id,
+                model.generation_config.max_length + model.decoder.config.num_codebooks,
+            )
+            labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
+            len_ = int(all_ratios[idx] * all_lens[idx])
+            # the first timestamp is associated to a row full of BOS, let's get rid of it
+            sample["labels"] = labels[:, 1:len_]
+            return sample
+
+        # TODO: done multiple times, how to deal with it.
+        with accelerator.main_process_first():
+            vectorized_datasets[split] = vectorized_datasets[split].map(
+                postprocess_dataset,
+                num_proc=num_workers,
+                desc="Postprocessing labeling",
+                with_indices=True,
+            )
+
+    accelerator.free_memory()
+    del generate_labels

     if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to:
         if is_wandb_available():
```
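This loop is the heart of the commit: instead of encoding on CPU via `Dataset.map`, the labels flow through a `DataLoader` that Accelerate shards across devices, then get gathered back in order on every rank. A self-contained sketch of the same pattern with a toy model (all names here are illustrative):

```python
# Standalone sketch of the distributed-inference pattern above; the Linear
# model and dataset are stand-ins for the EnCodec encoder and audio batches.
import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from tqdm import tqdm

accelerator = Accelerator()
model = torch.nn.Linear(4, 2).eval().to(accelerator.device)

dataset = [{"x": torch.randn(4)} for _ in range(17)]
collate = lambda feats: {"x": torch.stack([f["x"] for f in feats])}
loader = accelerator.prepare(DataLoader(dataset, batch_size=4, collate_fn=collate))

outputs = []
for batch in tqdm(loader, disable=not accelerator.is_local_main_process):
    with torch.no_grad():
        out = model(batch["x"])
    # gather_for_metrics collects results from all ranks and drops the duplicate
    # samples the sampler padded in, so len(outputs) == len(dataset) at the end
    outputs.extend(accelerator.gather_for_metrics(out).cpu())
```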
```diff
@@ -827,6 +874,7 @@ def main():
     # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
     # cached dataset
     if data_args.preprocessing_only:
+        # TODO: save to disk in this step instead of something else ??
         logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
         return
```
```diff
@@ -865,9 +913,9 @@ def main():
     # Now save everything to be able to create a single processor later
     # make sure all processes wait until data is saved
-    with training_args.main_process_first():
+    with accelerator.main_process_first():
         # only the main process saves them
-        if is_main_process(training_args.local_rank):
+        if accelerator.is_main_process:
             # save feature extractor, tokenizer and config
             if model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name):
                 prompt_tokenizer.save_pretrained(training_args.output_dir)
```
```diff
@@ -936,7 +984,7 @@ def main():
             audios = predictions["audio"]

             # log the table to wandb
-            self._wandb.log({"sample_songs": [self._wandb.Audio(audio, caption=text, sample_rate=audio_encoder_feature_extractor.sampling_rate) for (audio, text) in zip(audios, texts)]})
+            self._wandb.log({"sample_songs": [self._wandb.Audio(audio, caption=text, sample_rate=sampling_rate) for (audio, text) in zip(audios, texts)]})
```