chenpangpang / parler-tts · Commits · 997bf5e6
"docs/vscode:/vscode.git/clone" did not exist on "2e1d2d7e66c33fdd2b58aaf03a9893dbe593a3a3"
Commit 997bf5e6, authored Feb 23, 2024 by Yoach Lacombe
Parent: e25b8ba0

    make audio encoding multi-gpus compatible

Showing 1 changed file with 192 additions and 144 deletions.

run_stable_speech_training.py  (+192, -144)

@@ -32,6 +32,8 @@ from typing import Dict, List, Optional, Union
import datasets
import numpy as np
import torch
+from torch.utils.data import DataLoader
from datasets import DatasetDict, load_dataset, Dataset, IterableDataset, interleave_datasets, concatenate_datasets
import transformers

@@ -43,13 +45,15 @@ from transformers import (
    HfArgumentParser,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
-    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from transformers.integrations import is_wandb_available
-from accelerate import PartialState
+from accelerate import Accelerator
+from accelerate.utils import set_seed
from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig

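The import changes above replace `PartialState` with a full `Accelerator` and take `set_seed` from `accelerate.utils` rather than from `transformers`. A minimal sketch of how the two are typically used together (the seed value is only illustrative, not taken from this script):

    from accelerate import Accelerator
    from accelerate.utils import set_seed

    accelerator = Accelerator()  # one per process; exposes the rank/device info used throughout the script
    set_seed(42)                 # illustrative seed; seeds Python, NumPy and torch on every process

    if accelerator.is_main_process:
        print(f"running on {accelerator.num_processes} process(es)")
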
@@ -214,10 +218,6 @@ class DataSeq2SeqTrainingArguments:
        default="audio",
        metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
    )
-    conditional_audio_column_name: str = field(  # TODO
-        default=None,
-        metadata={"help": "The name of the dataset column containing the conditional audio data. Defaults to 'audio'"},
-    )
    description_column_name: str = field(  # TODO
        default=None,
        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},

@@ -311,6 +311,29 @@ class DataSeq2SeqTrainingArguments:
            "help": "id column name."
        }
    )

+@dataclass
+class DataCollatorEncodecWithPadding:
+    """
+    Data collator that will dynamically pad the inputs received to the longest sequence in the batch.
+    """
+
+    feature_extractor: AutoFeatureExtractor
+    feature_extractor_input_name: Optional[str] = "input_values"
+
+    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        # split inputs and labels since they have to be of different lengths and need
+        # different padding methods
+        audios = [torch.tensor(feature["labels"]).squeeze() for feature in features]
+        len_audio = [len(audio) for audio in audios]
+        max_audio = max(len_audio)
+
+        input_features = {self.feature_extractor_input_name: audios}
+        batch = self.feature_extractor.pad(input_features, return_tensors="pt", padding="longest", return_attention_mask=True)
+        batch[self.feature_extractor_input_name] = batch[self.feature_extractor_input_name].unsqueeze(1)  # add mono-channel
+        batch["padding_mask"] = batch.pop("attention_mask")
+        batch["len_audio"] = torch.tensor(len_audio).unsqueeze(1)
+        return batch

@dataclass

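The new `DataCollatorEncodecWithPadding` is what later feeds the distributed `DataLoader` used for encodec encoding. A small, self-contained sketch of what it returns; the checkpoint name and the fake waveforms are only illustrative:

    import torch
    from transformers import AutoFeatureExtractor

    # illustrative checkpoint; the training script passes the feature_extractor it loaded earlier instead
    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/encodec_24khz")
    collator = DataCollatorEncodecWithPadding(feature_extractor, feature_extractor_input_name="input_values")

    # two fake mono clips of different lengths, stored under "labels" as the script does
    features = [{"labels": torch.randn(24000)}, {"labels": torch.randn(16000)}]
    batch = collator(features)

    print(batch["input_values"].shape)        # (2, 1, 24000): padded to the longest clip, mono channel added
    print(batch["padding_mask"].shape)        # (2, 24000)
    print(batch["len_audio"].squeeze().tolist())  # [24000, 16000]
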
@@ -437,6 +460,7 @@ def convert_dataset_str_to_list(

def load_multiple_datasets(
+    accelerator: Accelerator,
    dataset_names: Union[List, str],
    dataset_config_names: Union[List, str],
    metadata_dataset_names: Optional[str] = None,

@@ -463,51 +487,52 @@ def load_multiple_datasets(
    all_datasets = []
    # iterate over the datasets we want to interleave
    for dataset_dict in tqdm(dataset_names_dict, desc="Combining datasets..."):
+        with accelerator.main_process_first():
            dataset = load_dataset(
                dataset_dict["name"],
                dataset_dict["config"],
                split=dataset_dict["split"],
                streaming=streaming,
                **kwargs,
            )
            dataset_features = dataset.features.keys()

            metadata_dataset_name = dataset_dict["metadata_dataset_name"]
            if metadata_dataset_name is not None:
                metadata_dataset = load_dataset(
                    metadata_dataset_name,
                    dataset_dict["config"],
                    split=dataset_dict["split"],
                    streaming=streaming,
                    **kwargs,
                )

                if id_column_name is not None and id_column_name not in dataset.column_names:
                    raise ValueError(
                        f"id_column_name={id_column_name} but has not been found in the dataset columns"
                        f"- one of {', '.join(list(dataset.column_names))}."
                    )
                if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
                    raise ValueError(
                        f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
                        f"- one of {', '.join(list(metadata_dataset.column_names))}."
                    )
                elif id_column_name is not None:
                    metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")

                metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
                metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
                dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)

                if id_column_name is not None:
                    if len(dataset.filter(lambda id1, id2: id1 != id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
                        raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")

                dataset_features = dataset.features.keys()

            if columns_to_keep is not None:
                dataset = dataset.remove_columns(set(dataset_features - columns_to_keep))
        all_datasets.append(dataset)

    if len(all_datasets) == 1:

@@ -522,7 +547,8 @@ def load_multiple_datasets(
            seed=seed,
        )
    else:
-        interleaved_dataset = concatenate_datasets(all_datasets)
+        with accelerator.main_process_first():
+            interleaved_dataset = concatenate_datasets(all_datasets)

    return interleaved_dataset

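Both hunks above lean on `Accelerator.main_process_first()`: the main process runs the block first (downloading the dataset and writing the cache) while the other ranks wait at a barrier, after which the remaining processes execute the same block and simply hit the cache. A minimal sketch of the pattern, assuming an existing `accelerator` and a purely hypothetical dataset name:

    from datasets import load_dataset

    # rank 0 downloads and caches first; the other ranks then re-run the same call and read from the cache
    with accelerator.main_process_first():
        dataset = load_dataset("my-org/my-audio-dataset", split="train")  # hypothetical dataset name
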
@@ -544,6 +570,8 @@ def main():
    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_stable_speech", model_args, data_args)

+    accelerator = Accelerator()
+
    # Detecting last checkpoint.
    last_checkpoint = None

@@ -566,7 +594,7 @@ def main():
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
-    logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+    logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)

    # Log on each process the small summary:
    logger.warning(

@@ -574,8 +602,9 @@ def main():
        f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
-    if is_main_process(training_args.local_rank):
+    if accelerator.is_main_process:
        transformers.utils.logging.set_verbosity_info()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.

@@ -585,71 +614,51 @@ def main():
    raw_datasets = DatasetDict()
    num_workers = data_args.preprocessing_num_workers

-    columns_to_keep = [data_args.target_audio_column_name, data_args.prompt_column_name]
+    columns_to_keep = {"target_audio_column_name": data_args.target_audio_column_name, "prompt_column_name": data_args.prompt_column_name}
    if data_args.description_column_name is not None:
-        columns_to_keep.append(data_args.description_column_name)
-    if data_args.conditional_audio_column_name is not None:
-        columns_to_keep.append(data_args.conditional_audio_column_name)
+        columns_to_keep["description_column_nam"] = data_args.description_column_name

    if training_args.do_train:
        raw_datasets["train"] = load_multiple_datasets(
+            accelerator,
            data_args.train_dataset_name,
            data_args.train_dataset_config_name,
-            data_args.train_metadata_dataset_name,
+            metadata_dataset_names=data_args.train_metadata_dataset_name,
            splits=data_args.train_split_name,
            dataset_samples=data_args.train_dataset_samples,
            seed=training_args.seed,
            cache_dir=model_args.cache_dir,
            num_proc=data_args.preprocessing_num_workers,
            id_column_name=data_args.id_column_name,
-            columns_to_keep=columns_to_keep,
+            columns_to_keep=columns_to_keep.values(),
            # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
        )

-        if data_args.target_audio_column_name not in raw_datasets["train"].column_names:
-            raise ValueError(
-                f"--target_audio_column_name '{data_args.target_audio_column_name}' not found in dataset '{data_args.train_dataset_name}'."
-                " Make sure to set `--target_audio_column_name` to the correct audio column - one of"
-                f" {', '.join(raw_datasets['train'].column_names)}."
-            )
-        if data_args.description_column_name is not None and data_args.description_column_name not in raw_datasets["train"].column_names:
-            raise ValueError(
-                f"--description_column_name {data_args.description_column_name} not found in dataset '{data_args.train_dataset_name}'. "
-                "Make sure to set `--description_column_name` to the correct text column - one of "
-                f"{', '.join(raw_datasets['train'].column_names)}."
-            )
-        if data_args.prompt_column_name not in raw_datasets["train"].column_names:
-            raise ValueError(
-                f"--description_column_name {data_args.prompt_column_name} not found in dataset '{data_args.train_dataset_name}'. "
-                "Make sure to set `--prompt_column_name` to the correct text column - one of "
-                f"{', '.join(raw_datasets['train'].column_names)}."
-            )
-        if data_args.conditional_audio_column_name is not None and data_args.conditional_audio_column_name not in raw_datasets["train"].column_names:
-            raise ValueError(
-                f"--conditional_audio_column_name {data_args.conditional_audio_column_name} not found in dataset '{data_args.train_dataset_name}'. "
-                "Make sure to set `--conditional_audio_column_name` to the correct text column - one of "
-                f"{', '.join(raw_datasets['train'].column_names)}."
-            )
+        for key in columns_to_keep:
+            if columns_to_keep[key] not in raw_datasets["train"].column_names:
+                raise ValueError(
+                    f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
+                    f" Make sure to set `--{key}` to the correct audio column - one of"
+                    f" {', '.join(raw_datasets['train'].column_names)}."
+                )

        if data_args.max_train_samples is not None:
            raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))

    if training_args.do_eval:
        raw_datasets["eval"] = load_multiple_datasets(
+            accelerator,
            data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
            data_args.eval_dataset_config_name if data_args.eval_dataset_config_name else data_args.train_dataset_config_name,
-            data_args.eval_metadata_dataset_name,
+            metadata_dataset_names=data_args.eval_metadata_dataset_name,
            splits=data_args.eval_split_name,
            cache_dir=model_args.cache_dir,
            num_proc=data_args.preprocessing_num_workers,
            id_column_name=data_args.id_column_name,
-            columns_to_keep=columns_to_keep,
+            columns_to_keep=columns_to_keep.values(),
            # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
        )

@@ -657,9 +666,8 @@ def main():
            raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))

-    # TODO: is is the right way to do ?
-    # 2. Next, let's load the config as we might need it to create
-    # load config TODO: add the option to create the config from scratch
+    # 3. Next, let's load the config as we might need it to create
+    # load config
    config = StableSpeechConfig.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,

@@ -673,8 +681,7 @@ def main():
        "decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else model.config.decoder_start_token_id,
    })

-    # 3. Now we can instantiate the feature extractor, tokenizers and model
+    # 4. Now we can instantiate the feature extractor, tokenizers and model
    # Note for distributed training, the .from_pretrained methods guarantee that only
    # one local process can concurrently download model & vocab.

@@ -692,16 +699,24 @@ def main():
...
@@ -692,16 +699,24 @@ def main():
cache_dir
=
model_args
.
cache_dir
,
cache_dir
=
model_args
.
cache_dir
,
token
=
data_args
.
token
,
token
=
data_args
.
token
,
trust_remote_code
=
data_args
.
trust_remote_code
,
trust_remote_code
=
data_args
.
trust_remote_code
,
use_fast
=
model_args
.
use_fast_tokenizer
,
)
)
# load description tokenizer
# load description tokenizer
description_tokenizer
=
AutoTokenizer
.
from_pretrained
(
description_tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_args
.
description_tokenizer_name
or
model_args
.
model_name_or_path
,
model_args
.
description_tokenizer_name
or
model_args
.
model_name_or_path
,
cache_dir
=
model_args
.
cache_dir
,
cache_dir
=
model_args
.
cache_dir
,
token
=
data_args
.
token
,
token
=
data_args
.
token
,
trust_remote_code
=
data_args
.
trust_remote_code
,
trust_remote_code
=
data_args
.
trust_remote_code
,
use_fast
=
model_args
.
use_fast_tokenizer
,
)
)
if
model_args
.
use_fast_tokenizer
:
logger
.
warning
(
"Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
)
prompt_tokenizer
.
deprecation_warnings
[
"Asking-to-pad-a-fast-tokenizer"
]
=
True
description_tokenizer
.
deprecation_warnings
[
"Asking-to-pad-a-fast-tokenizer"
]
=
True
# create model + TODO: not from_pretrained probably
# create model + TODO: not from_pretrained probably
model
=
StableSpeechForConditionalGeneration
.
from_pretrained
(
model
=
StableSpeechForConditionalGeneration
.
from_pretrained
(
model_args
.
model_name_or_path
,
model_args
.
model_name_or_path
,
@@ -711,46 +726,31 @@ def main():
        trust_remote_code=data_args.trust_remote_code,
    )

-    # take audio_encoder_feature_extractor
-    audio_encoder_feature_extractor = AutoFeatureExtractor.from_pretrained(
-        model.config.audio_encoder._name_or_path,
-    )
-
-    # 5. Now we preprocess the datasets including loading the audio, resampling and normalization
+    # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
    # so that we just need to set the correct target sampling rate and normalize the input
    # via the `feature_extractor`
-    # resample target audio
-    raw_datasets = raw_datasets.cast_column(
-        data_args.target_audio_column_name, datasets.features.Audio(sampling_rate=audio_encoder_feature_extractor.sampling_rate)
-    )
-    if data_args.conditional_audio_column_name is not None:
-        raw_datasets = raw_datasets.cast_column(
-            data_args.conditional_audio_column_name, datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate)
-        )

    # derive max & min input length for sample rate & max duration
-    max_target_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
-    min_target_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
+    sampling_rate = feature_extractor.sampling_rate
+    max_target_length = data_args.max_duration_in_seconds * sampling_rate
+    min_target_length = data_args.min_duration_in_seconds * sampling_rate
    target_audio_column_name = data_args.target_audio_column_name
-    conditional_audio_column_name = data_args.conditional_audio_column_name
    description_column_name = data_args.description_column_name
    prompt_column_name = data_args.prompt_column_name
    feature_extractor_input_name = feature_extractor.model_input_names[0]

+    # resample target audio
+    raw_datasets = raw_datasets.cast_column(
+        target_audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
+    )
+
    # Preprocessing the datasets.
-    # We need to read the audio files as arrays and tokenize the targets.
+    # We need to read the audio files as arrays and tokenize the texts.
    def pass_through_processors(batch):
        # load audio
-        if conditional_audio_column_name is not None:
-            sample = batch[target_audio_column_name]
-            inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
-            batch[feature_extractor_input_name] = getattr(inputs, feature_extractor_input_name)[0]
-
        if description_column_name is not None:
            text = batch[description_column_name]
            batch["input_ids"] = description_tokenizer(text.strip())["input_ids"]

@@ -761,14 +761,14 @@ def main():
...
@@ -761,14 +761,14 @@ def main():
# load audio
# load audio
target_sample
=
batch
[
target_audio_column_name
]
target_sample
=
batch
[
target_audio_column_name
]
labels
=
audio_encoder_
feature_extractor
(
target_sample
[
"array"
],
sampling_rate
=
target_sample
[
"sampling_rate"
])
labels
=
feature_extractor
(
target_sample
[
"array"
],
sampling_rate
=
target_sample
[
"sampling_rate"
])
batch
[
"labels"
]
=
labels
[
"input_values"
]
batch
[
"labels"
]
=
labels
[
"input_values"
]
# take length of raw audio waveform
# take length of raw audio waveform
batch
[
"target_length"
]
=
len
(
target_sample
[
"array"
].
squeeze
())
batch
[
"target_length"
]
=
len
(
target_sample
[
"array"
].
squeeze
())
return
batch
return
batch
with
training_args
.
main_process_first
(
desc
=
"dataset map preprocessing"
):
with
accelerator
.
main_process_first
():
vectorized_datasets
=
raw_datasets
.
map
(
vectorized_datasets
=
raw_datasets
.
map
(
pass_through_processors
,
pass_through_processors
,
remove_columns
=
next
(
iter
(
raw_datasets
.
values
())).
column_names
,
remove_columns
=
next
(
iter
(
raw_datasets
.
values
())).
column_names
,
...
@@ -785,34 +785,81 @@ def main():
...
@@ -785,34 +785,81 @@ def main():
num_proc
=
num_workers
,
num_proc
=
num_workers
,
input_columns
=
[
"target_length"
],
input_columns
=
[
"target_length"
],
)
)
# 5. Now we encode the audio labels with encodec.
# We use Accelerate to perform distributed inference
logger
.
info
(
"*** Encode target audio with encodec ***"
)
# no need to prepare audio_decoder because used for inference without mixed precision
# see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
# TODO: load another model
audio_decoder
=
model
.
audio_encoder
audio_decoder
=
model
.
audio_encoder
encoder_data_collator
=
DataCollatorEncodecWithPadding
(
feature_extractor
,
feature_extractor_input_name
)
def
apply_audio_decoder
(
batch
):
def
apply_audio_decoder
(
batch
):
labels
=
audio_decoder
.
encode
(
torch
.
tensor
(
batch
[
"labels"
]).
to
(
audio_decoder
.
device
))[
"audio_codes"
]
len_audio
=
batch
.
pop
(
"len_audio"
)
labels
,
delay_pattern_mask
=
model
.
decoder
.
build_delay_pattern_mask
(
labels
,
audio_decoder
.
to
(
batch
[
"input_values"
].
device
).
eval
()
model
.
generation_config
.
decoder_start_token_id
,
labels
=
audio_decoder
.
encode
(
**
batch
)[
"audio_codes"
]
model
.
generation_config
.
max_length
+
1
)
output
=
{}
output
[
"len_audio"
]
=
len_audio
labels
=
model
.
decoder
.
apply_delay_pattern_mask
(
labels
,
delay_pattern_mask
)
# (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
output
[
"labels"
]
=
labels
.
squeeze
(
0
).
transpose
(
1
,
2
)
# the first timestamp is associated to a row full of BOS, let's get rid of it
output
[
"ratio"
]
=
torch
.
ones_like
(
len_audio
)
*
labels
.
shape
[
-
1
]
/
len_audio
.
max
()
batch
[
"labels"
]
=
labels
[:,
1
:]
return
output
return
batch
for
split
in
vectorized_datasets
:
with
training_args
.
main_process_first
(
desc
=
"audio target preprocessing"
):
data_loader
=
DataLoader
(
# for now on CPU
vectorized_datasets
[
split
],
# TODO: enrich for GPU
batch_size
=
training_args
.
per_device_eval_batch_size
,
vectorized_datasets
=
vectorized_datasets
.
map
(
collate_fn
=
encoder_data_collator
,
apply_audio_decoder
,
num_workers
=
training_args
.
dataloader_num_workers
,
num_proc
=
num_workers
,
pin_memory
=
True
,
desc
=
"preprocess datasets"
,
)
)
data_loader
=
accelerator
.
prepare
(
data_loader
)
# TODO: will it work on GPU ? unmerged for now https://github.com/huggingface/accelerate/pull/2433
# for split in vectorized_datasets:
all_generated_labels
=
[]
# with distributed_state.split_between_processes(vectorized_datasets[split]["labels"]) as input_labels:
all_ratios
=
[]
# result = audio_decoder(input_labels)
all_lens
=
[]
for
batch
in
tqdm
(
data_loader
,
disable
=
not
accelerator
.
is_local_main_process
):
generate_labels
=
apply_audio_decoder
(
batch
)
generate_labels
=
accelerator
.
pad_across_processes
(
generate_labels
,
dim
=
1
,
pad_index
=
0
)
generate_labels
=
accelerator
.
gather_for_metrics
(
generate_labels
)
all_generated_labels
.
extend
(
generate_labels
[
"labels"
].
cpu
())
all_ratios
.
extend
(
generate_labels
[
"ratio"
].
cpu
())
all_lens
.
extend
(
generate_labels
[
"len_audio"
].
cpu
())
def
postprocess_dataset
(
sample
,
idx
):
# (1, seq_len, codebooks, bsz)
labels
=
all_generated_labels
[
idx
].
transpose
(
0
,
1
).
unsqueeze
(
0
)
labels
,
delay_pattern_mask
=
model
.
decoder
.
build_delay_pattern_mask
(
labels
,
model
.
generation_config
.
decoder_start_token_id
,
model
.
generation_config
.
max_length
+
model
.
decoder
.
config
.
num_codebooks
)
labels
=
model
.
decoder
.
apply_delay_pattern_mask
(
labels
,
delay_pattern_mask
)
len_
=
int
(
all_ratios
[
idx
]
*
all_lens
[
idx
])
# the first timestamp is associated to a row full of BOS, let's get rid of it
sample
[
"labels"
]
=
labels
[:,
1
:
len_
]
return
sample
# TODO: done multiple times, how to deal with it.
with
accelerator
.
main_process_first
():
vectorized_datasets
[
split
]
=
vectorized_datasets
[
split
].
map
(
postprocess_dataset
,
num_proc
=
num_workers
,
desc
=
"Postprocessing labeling"
,
with_indices
=
True
,
)
accelerator
.
free_memory
()
del
generate_labels
if
data_args
.
add_audio_samples_to_wandb
and
"wandb"
in
training_args
.
report_to
:
if
data_args
.
add_audio_samples_to_wandb
and
"wandb"
in
training_args
.
report_to
:
if
is_wandb_available
():
if
is_wandb_available
():
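The block above is the core of the commit: instead of running the encodec encoder sample-by-sample inside `datasets.map`, each process now encodes its own shard of a `DataLoader` prepared by Accelerate, and the per-rank outputs are padded and gathered so every process ends up with the full list of encoded labels. A stripped-down sketch of that gather pattern, assuming an existing `accelerator`, a prepared `data_loader`, and an `encode_fn` standing in for the script's `apply_audio_decoder`:

    import torch

    all_labels = []
    for batch in data_loader:                  # each rank sees a different shard of the batches
        with torch.no_grad():
            out = encode_fn(batch)             # returns a dict containing a "labels" tensor, as in the script
        # pad the sequence dimension so every rank contributes equally shaped tensors, then gather;
        # gather_for_metrics also drops the samples duplicated to fill the last batch
        out = accelerator.pad_across_processes(out, dim=1, pad_index=0)
        out = accelerator.gather_for_metrics(out)
        all_labels.extend(out["labels"].cpu())
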
@@ -827,6 +874,7 @@ def main():
    # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
    # cached dataset
    if data_args.preprocessing_only:
+        # TODO: save to disk in this step instead of something else ??
        logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
        return

@@ -865,9 +913,9 @@ def main():
    # Now save everything to be able to create a single processor later
    # make sure all processes wait until data is saved
-    with training_args.main_process_first():
+    with accelerator.main_process_first():
        # only the main process saves them
-        if is_main_process(training_args.local_rank):
+        if accelerator.is_main_process:
            # save feature extractor, tokenizer and config
            if model_args.prompt_tokenizer_name is None and model_args.description_tokenizer_name or (model_args.prompt_tokenizer_name == model_args.description_tokenizer_name):
                prompt_tokenizer.save_pretrained(training_args.output_dir)

@@ -936,7 +984,7 @@ def main():
                audios = predictions["audio"]

                # log the table to wandb
-                self._wandb.log({"sample_songs": [self._wandb.Audio(audio, caption=text, sample_rate=audio_encoder_feature_extractor.sampling_rate) for (audio, text) in zip(audios, texts)]})
+                self._wandb.log({"sample_songs": [self._wandb.Audio(audio, caption=text, sample_rate=sampling_rate) for (audio, text) in zip(audios, texts)]})