chenpangpang / parler-tts · Commits

Commit 43087d4a
authored Apr 08, 2024 by Yoach Lacombe

clean training script

parent c734f3ec

Showing 1 changed file with 29 additions and 51 deletions:

run_parler_tts_training.py (+29 −51)
@@ -52,9 +52,7 @@ from transformers import (
 from transformers.trainer_pt_utils import LengthGroupedSampler
 from transformers import pipeline
 from transformers.optimization import get_scheduler
 from transformers.utils import check_min_version, send_example_telemetry
 from transformers.utils.versions import require_version
 from transformers.integrations import is_wandb_available
-from transformers.utils import send_example_telemetry
 from transformers import AutoModel
@@ -68,13 +66,7 @@ from parler_tts import (
     build_delay_pattern_mask,
 )

+if is_wandb_available():
+    from wandb import Audio

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version("4.38.0.dev0")

 require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")

-from wandb import Audio

 logger = logging.getLogger(__name__)
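Note on the import change above: `is_wandb_available()` is the standard transformers guard for optional dependencies, so `wandb.Audio` is only imported when wandb is actually installed. A minimal sketch of the same pattern, with an illustrative `Audio = None` fallback that is not part of this commit:

    from transformers.integrations import is_wandb_available

    Audio = None  # fallback so callers can check `Audio is not None` before logging
    if is_wandb_available():
        from wandb import Audio  # resolved only when wandb is installed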
@@ -202,7 +194,6 @@ class ModelArguments:
     Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
     """

-    # TODO: pretrain from scratch
     model_name_or_path: str = field(
         metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
     )
@@ -256,9 +247,18 @@ class ModelArguments:
         metadata={"help": "Generation max length."},
     )
     bandwidth: float = field(
-        default=6,  # TODO
+        default=6,
         metadata={"help": "Audio encoder bandwidth."},
     )
+    asr_model_name_or_path: str = field(
+        default="distil-whisper/distil-large-v2",
+        metadata={
+            "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
+    )
+    clap_model_name_or_path: str = field(
+        default="laion/larger_clap_music_and_speech",
+        metadata={
+            "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
+    )

 @dataclass
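Since `ModelArguments` is parsed with `HfArgumentParser`, the two new fields surface directly as CLI flags. A minimal sketch of how that works, with `ModelArguments` trimmed to just the new fields for illustration:

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    @dataclass
    class ModelArguments:
        asr_model_name_or_path: str = field(default="distil-whisper/distil-large-v2")
        clap_model_name_or_path: str = field(default="laion/larger_clap_music_and_speech")

    # e.g. `python run.py --asr_model_name_or_path openai/whisper-large-v3`
    (model_args,) = HfArgumentParser(ModelArguments).parse_args_into_dataclasses()
    print(model_args.asr_model_name_or_path)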
@@ -333,17 +333,17 @@ class DataTrainingArguments:
             " librispeech and common voice, set `train_dataset_name='librispeech_asr+common_voice'`."
         },
     )
-    target_audio_column_name: str = field(  # TODO
+    target_audio_column_name: str = field(
         default="audio",
         metadata={"help": "The name of the dataset column containing the target audio data. Defaults to 'audio'"},
     )
-    description_column_name: str = field(  # TODO
+    description_column_name: str = field(
         default=None,
-        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
+        metadata={"help": "The name of the dataset column containing the description text data. Defaults to 'None'."},
     )
-    prompt_column_name: str = field(  # TODO
+    prompt_column_name: str = field(
         default=None,
-        metadata={"help": "The name of the dataset column containing the text data. Defaults to 'None'."},
+        metadata={"help": "The name of the dataset column containing the prompt text data. Defaults to 'None'."},
     )
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
@@ -482,9 +482,9 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
             )
         },
     )
-    audio_encode_per_device_eval_batch_size: int = field(
+    audio_encoder_per_device_batch_size: int = field(
         default=8,
-        metadata={"help": ("TODO")},
+        metadata={"help": ("Specify the batch size of the audio encoding pre-processing steps.")},
     )
@@ -521,8 +521,6 @@ class DataCollatorParlerTTSWithPadding:
             The prompt_tokenizer used for proccessing the data.
         description_tokenizer (:class:`~transformers.AutoTokenizer`)
             The description_tokenizer used for proccessing the data.
-        audio_feature_extractor (:class:`~transformers.AutoFeatureExtractor`)
-            The audio_feature_extractor used for proccessing the data.
         padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
             Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
             among:
@@ -540,8 +538,6 @@ class DataCollatorParlerTTSWithPadding:
     prompt_tokenizer: AutoTokenizer
     description_tokenizer: AutoTokenizer
-    audio_feature_extractor: AutoFeatureExtractor
-    feature_extractor_input_name: Optional[str] = "input_values"
     padding: Union[bool, str] = "longest"
     pad_to_multiple_of: Optional[int] = None
     prompt_max_length: Optional[int] = None
@@ -588,15 +584,6 @@ class DataCollatorParlerTTSWithPadding:
         if "attention_mask" in prompt_input_ids:
             batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]

-        if self.feature_extractor_input_name in features[0]:
-            # TODO (YL): verify that it works - IMPORTANT -> probably not working
-            input_values = [{self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features]
-            input_values = self.feature_extractor.pad(input_values, return_tensors="pt")
-            batch[self.feature_extractor_input_name] = input_values

         return batch
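With the feature-extractor branch deleted, the collator's remaining padding work goes through the tokenizers; the `prompt_attention_mask` kept above comes straight out of `tokenizer.pad`. A minimal sketch of that behavior, assuming an arbitrary tokenizer checkpoint for illustration:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("google/flan-t5-base")  # illustrative checkpoint
    features = [
        {"input_ids": tok("a short prompt")["input_ids"]},
        {"input_ids": tok("a much longer prompt that forces padding")["input_ids"]},
    ]
    batch = tok.pad(features, padding="longest", return_tensors="pt")
    # `pad` returns both the padded ids and the matching attention mask
    print(batch["input_ids"].shape, batch["attention_mask"].shape)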
@@ -1019,7 +1006,6 @@ def main():
     )

     # 3. Next, let's load the config.
-    # TODO(YL): add the option to create the config from scratch
     config = ParlerTTSConfig.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
@@ -1028,7 +1014,6 @@ def main():
     )

     # update pad token id and decoder_start_token_id
-    # TODO(YL): verify if this makes sense, maybe should do it for model.decoder
     config.update(
         {
             "pad_token_id": model_args.pad_token_id
@@ -1040,7 +1025,7 @@ def main():
         }
     )

-    # create model + TODO(YL): not from_pretrained probably
+    # create model
     model = ParlerTTSForConditionalGeneration.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
@@ -1076,7 +1061,6 @@ def main():
     # Freeze Encoders
     model.freeze_encoders(model_args.freeze_text_encoder)

-    # TODO: remove when releasing
     # Test all gather - used for warmout and avoiding timeout
     test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
     gathered_tensor = accelerator.gather(test_tensor)
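The all-gather smoke test kept above runs one cheap collective op before the long preprocessing stage, so distributed communication is initialized (and any misconfiguration surfaces) before it can cause a timeout mid-run. A minimal standalone sketch, with the assert added for illustration:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()
    test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
    gathered = accelerator.gather(test_tensor)  # one entry per process
    assert gathered.numel() == accelerator.num_processes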
@@ -1100,7 +1084,6 @@ def main():
         batch = {}
         batch["input_ids"] = description_tokenizer(description.strip())["input_ids"]
-        # TODO: add possibility to train without description column
         batch["prompt_input_ids"] = prompt_tokenizer(prompt.strip())["input_ids"]
         return batch
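For context: Parler-TTS conditions on two text streams, a style description fed to the text encoder (`input_ids`) and the transcript prompt fed to the decoder (`prompt_input_ids`), each with its own tokenizer. A hedged sketch of how such a mapping function is typically applied with `datasets.map`; the tokenizer checkpoints and column names below are illustrative stand-ins, not the script's exact setup:

    from datasets import Dataset
    from transformers import AutoTokenizer

    description_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
    prompt_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

    def pass_through_processors(description, prompt):
        return {
            "input_ids": description_tokenizer(description.strip())["input_ids"],
            "prompt_input_ids": prompt_tokenizer(prompt.strip())["input_ids"],
        }

    ds = Dataset.from_dict({"description": ["a calm female voice"], "text": ["Hello world."]})
    ds = ds.map(lambda ex: pass_through_processors(ex["description"], ex["text"]))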
@@ -1154,7 +1137,7 @@ def main():
     for split in vectorized_datasets:
         data_loader = DataLoader(
             raw_datasets[split],
-            batch_size=training_args.audio_encode_per_device_eval_batch_size,
+            batch_size=training_args.audio_encoder_per_device_batch_size,
             collate_fn=encoder_data_collator,
             num_workers=training_args.dataloader_num_workers,
             pin_memory=True,
@@ -1221,7 +1204,6 @@ def main():
         output = {"labels": labels[:, 1:]}
         return output

-    # TODO(YL): done multiple times, how to deal with it.
     with accelerator.main_process_first():
         vectorized_datasets[split] = vectorized_datasets[split].map(
             postprocess_dataset,
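`accelerator.main_process_first()` makes rank 0 run the `datasets.map` call first so its on-disk cache is populated; the remaining ranks then enter the block and load the cached result instead of recomputing it, which also softens the removed TODO's concern about the map running "multiple times". A self-contained sketch of the pattern, with a toy dataset for illustration:

    from accelerate import Accelerator
    from datasets import Dataset

    accelerator = Accelerator()
    ds = Dataset.from_dict({"x": [1, 2, 3]})

    # rank 0 enters first and writes the map cache; other ranks then reuse it
    with accelerator.main_process_first():
        ds = ds.map(lambda ex: {"y": ex["x"] * 2})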
@@ -1302,9 +1284,9 @@ def main():
     # Let's use word CLAP similary and WER metrics as our evaluation metrics,

-    # Define evaluation metrics during training, *i.e.* CLAP similarity TODO: allow using another CLAP
-    clap = AutoModel.from_pretrained("laion/larger_clap_music_and_speech")
-    clap_processor = AutoProcessor.from_pretrained("laion/larger_clap_music_and_speech")
+    # Define evaluation metrics during training, *i.e.* CLAP similarity
+    clap = AutoModel.from_pretrained(model_args.clap_model_name_or_path)
+    clap_processor = AutoProcessor.from_pretrained(model_args.clap_model_name_or_path)
     metric = evaluate.load("wer")

     def clap_similarity(texts, audios, device):
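For context on the CLAP metric this hunk makes configurable: CLAP embeds the description text and the generated audio into one space, and the metric is the mean cosine similarity between paired embeddings. A minimal sketch under the assumption of 48 kHz mono input; the helper below is illustrative, not the script's exact function:

    import numpy as np
    import torch
    from transformers import AutoModel, AutoProcessor

    clap = AutoModel.from_pretrained("laion/larger_clap_music_and_speech")
    clap_processor = AutoProcessor.from_pretrained("laion/larger_clap_music_and_speech")

    def clap_similarity(texts, audios, sampling_rate=48_000):
        inputs = clap_processor(
            text=texts, audios=audios, sampling_rate=sampling_rate,
            return_tensors="pt", padding=True,
        )
        with torch.no_grad():
            out = clap(**inputs)
        # cosine similarity between paired text and audio embeddings
        return torch.nn.functional.cosine_similarity(
            out.text_embeds, out.audio_embeds, dim=-1
        ).mean()

    print(clap_similarity(["a calm female voice"], [np.random.randn(48_000)]))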
@@ -1323,7 +1305,7 @@ def main():
         return cosine_sim.mean().to("cpu")

     def wer(prompts, audios, device):
-        asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device)
+        asr_pipeline = pipeline(model=model_args.asr_model_name_or_path, device=device)
         transcriptions = asr_pipeline(
             [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
             batch_size=int(training_args.per_device_eval_batch_size),
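And the WER side: the generated audio is transcribed with the now-configurable ASR checkpoint and scored against the input prompts using the `evaluate` WER metric. A minimal sketch, with plain lowercasing standing in for the script's full text normalization (an assumption):

    import evaluate
    from transformers import pipeline

    def wer(prompts, audios, sampling_rate, device="cpu"):
        asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device)
        transcriptions = asr_pipeline(
            [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios]
        )
        metric = evaluate.load("wer")
        return metric.compute(
            predictions=[t["text"].lower() for t in transcriptions],
            references=[p.lower() for p in prompts],
        )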
@@ -1394,8 +1376,6 @@ def main():
     # Instantiate custom data collator
     data_collator = DataCollatorParlerTTSWithPadding(
-        audio_feature_extractor=feature_extractor,
-        feature_extractor_input_name=feature_extractor_input_name,
         prompt_tokenizer=prompt_tokenizer,
         description_tokenizer=description_tokenizer,
         pad_to_multiple_of=data_args.pad_to_multiple_of,
@@ -1531,7 +1511,6 @@ def main():
         outputs = model(**batch)
         # CE (data) loss
         ce_loss = outputs.loss
-        # TODO: add CE per codebook
         metrics = {"loss": ce_loss}
         return ce_loss, metrics
@@ -1578,8 +1557,9 @@ def main():
     for epoch in range(epochs_trained, num_epochs):
         vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
-        # TODO(YL): add args
-        sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
+        sampler = None
+        if training_args.group_by_length:
+            sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
         train_dataloader = DataLoader(
             vectorized_datasets["train"],
             collate_fn=data_collator,
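The rewritten sampler logic makes length grouping opt-in via `group_by_length`. `LengthGroupedSampler` batches examples of similar length to cut padding waste; because `DataLoader` rejects `shuffle=True` together with a custom `sampler`, the script shuffles the dataset itself at the top of each epoch and passes the sampler (or `None`) below. A self-contained sketch of the wiring, with toy lengths for illustration:

    from torch.utils.data import DataLoader
    from transformers.trainer_pt_utils import LengthGroupedSampler

    lengths = [120, 80, 400, 95, 410, 130]     # illustrative target lengths
    dataset = list(range(len(lengths)))        # stand-in dataset
    group_by_length, batch_size = True, 2

    sampler = None
    if group_by_length:
        sampler = LengthGroupedSampler(batch_size, lengths=lengths)

    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)
    print([batch.tolist() for batch in loader])  # similar lengths land in the same batch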
@@ -1631,7 +1611,7 @@ def main():
             # save checkpoint and weights after each save_steps and at the end of training
             if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
                 intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
-                # safe_serialization=False to avoid shared tensors saving issue (TODO: it's a temporary fix)
+                # safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
                 # https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
                 accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
                 accelerator.wait_for_everyone()
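Context for the comment fix above: safetensors serialization refuses tensors that share memory (tied weights), which is why checkpoints are written with `safe_serialization=False` (pickle-based `.bin` files) until the linked transformers issue is resolved. A minimal sketch of the call, paths illustrative:

    from accelerate import Accelerator

    accelerator = Accelerator()
    # falls back to torch.save-style .bin files instead of safetensors
    accelerator.save_state(output_dir="output/checkpoint-1000-epoch-0", safe_serialization=False)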
@@ -1701,8 +1681,6 @@ def main():
             ):
                 generated_audios = generate_step(batch)
                 # Gather all predictions and targets
-                # TODO: also add prompt ids
-                # TODO: better gather
                 generated_audios, input_ids, prompts = accelerator.pad_across_processes(
                     (generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
                 )