chenpangpang / parler-tts

Commit 5acad845, authored Mar 14, 2024 by Yoach Lacombe
add possibility to precompute text hidden states + fix generation
Parent: 80da6b4c

Showing 2 changed files with 142 additions and 36 deletions:

run_stable_speech_training.py           (+141, -35)
stable_speech/modeling_stable_speech.py (+1, -1)
run_stable_speech_training.py
@@ -63,7 +63,7 @@ from transformers.utils.versions import require_version
 from transformers.integrations import is_wandb_available
 from transformers import AutoConfig, AutoModel
 from stable_speech import DACConfig, DACModel
 from transformers.modeling_outputs import BaseModelOutput

 AutoConfig.register("dac", DACConfig)
 AutoModel.register(DACConfig, DACModel)
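The two register calls in this hunk hook the custom DAC codec into transformers' Auto-class machinery. A minimal sketch of what that enables, assuming the registrations above have run (the instantiation itself is hypothetical, not part of the commit):

# Sketch: once "dac" is registered, the Auto classes resolve it like any
# built-in architecture.
from transformers import AutoConfig, AutoModel

config = AutoConfig.for_model("dac")   # resolves to a DACConfig instance
codec = AutoModel.from_config(config)  # builds a DACModel from that config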
@@ -464,6 +464,14 @@ class StableSpeechTrainingArguments(Seq2SeqTrainingArguments):
             )
         },
     )
+    text_encode_per_device_eval_batch_size: int = field(
+        default=8,
+        metadata={
+            "help": (
+                "TODO"
+            )
+        },
+    )


 @dataclass
 class DataCollatorEncodecWithPadding:
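The new field controls the per-device batch size used when pre-encoding descriptions with the frozen text encoder. A hedged usage sketch (the instantiation below is illustrative; output_dir is a placeholder required by the base class):

# Hypothetical instantiation; the field behaves like any other
# Seq2SeqTrainingArguments entry and defaults to 8.
args = StableSpeechTrainingArguments(
    output_dir="./out",                          # placeholder, required by the base class
    text_encode_per_device_eval_batch_size=16,   # batch size for the text-encoding pass
)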
@@ -531,9 +539,16 @@ class DataCollatorStableSpeechWithPadding:
         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)

         input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
-        input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
-        batch = {"labels": labels, **input_ids}
+        if "encoder_outputs" in features[0]:
+            input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding)
+            encoder_hidden_states = [torch.tensor(feature["encoder_outputs"]["last_hidden_state"]) for feature in features]
+            encoder_hidden_states = torch.nn.utils.rnn.pad_sequence(encoder_hidden_states, batch_first=True, padding_value=0.)
+            batch = {"labels": labels, "encoder_outputs": BaseModelOutput(last_hidden_state=encoder_hidden_states), **input_ids}
+        else:
+            input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
+            batch = {"labels": labels, **input_ids}

         prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
         prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
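A note on the collator above: labels are padded with -100 because that is the target index PyTorch's cross-entropy loss ignores, so padded positions contribute no loss. A toy illustration:

import torch

# Two label sequences of unequal length, padded the same way the collator does.
labels = [torch.tensor([5, 6, 7]), torch.tensor([8])]
padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
# tensor([[   5,    6,    7],
#         [   8, -100, -100]])  -> the -100 positions are ignored by the loss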
@@ -550,7 +565,40 @@ class DataCollatorStableSpeechWithPadding:
         batch[self.feature_extractor_input_name] = input_values

         return batch


+@dataclass
+class T5TextCollatorStableSpeechWithPadding:
+    """
+    Data collator that will dynamically pad the inputs received.
+    Args:
+        description_tokenizer (:class:`~transformers.AutoTokenizer`)
+            The description_tokenizer used for processing the data.
+        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+            among:
+            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
+              sequence is provided).
+            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
+              maximum acceptable input length for the model if that argument is not provided.
+            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
+              different lengths).
+        pad_to_multiple_of (:obj:`int`, `optional`):
+            If set, will pad the sequence to a multiple of the provided value.
+            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+            7.5 (Volta).
+    """
+
+    description_tokenizer: AutoTokenizer
+    padding: Union[bool, str] = "longest"
+    pad_to_multiple_of: Optional[int] = None
+
+    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
+        input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
+        input_ids_len = [len(feature["input_ids"]) for feature in features]
+        input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
+
+        batch = {"len_input_ids": torch.tensor(input_ids_len), **input_ids}
+
+        return batch


 def convert_dataset_str_to_list(
     dataset_names,
     ...
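A hedged usage sketch of the new collator, assuming description_tokenizer is an already-loaded T5 tokenizer (the token ids below are arbitrary):

# Collate two tokenized descriptions; the true lengths are kept in
# "len_input_ids" so padding can be stripped again after encoding.
collator = T5TextCollatorStableSpeechWithPadding(description_tokenizer, pad_to_multiple_of=8)
batch = collator([{"input_ids": [101, 42, 7]}, {"input_ids": [101, 13]}])
print(batch["input_ids"].shape)  # (2, padded_length)
print(batch["len_input_ids"])    # tensor([3, 2])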
@@ -729,6 +777,7 @@ def main():
     else:
         mixed_precision = "no"

+    ####### A. Preparation

     accelerator = Accelerator(
         gradient_accumulation_steps=training_args.gradient_accumulation_steps,
         mixed_precision=mixed_precision,
@@ -926,6 +975,9 @@ def main():
         trust_remote_code=data_args.trust_remote_code,
     )

+    # enable gradient checkpointing if necessary
+    if training_args.gradient_checkpointing:
+        model.gradient_checkpointing_enable()

     # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
     # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
@@ -947,6 +999,9 @@ def main():
     num_codebooks = model.decoder.config.num_codebooks
     bandwidth = model_args.bandwidth

+    # Freeze Encoders
+    model.freeze_encoders(model_args.freeze_text_encoder)
+
     if not dataset_was_precomputed:
         # resample target audio
         raw_datasets = raw_datasets.cast_column(
@@ -996,6 +1051,8 @@ def main():
     # 5. Now we encode the audio labels with encodec.
     # We use Accelerate to perform distributed inference

+    ####### B. Encode audio
+
     logger.info("*** Encode target audio with encodec ***")
     # no need to prepare audio_decoder because used for inference without mixed precision
@@ -1090,6 +1147,52 @@ def main():
     accelerator.free_memory()
     del generate_labels

+    # T5 doesn't support fp16
+    autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
+
+    ####### C. Encode text if the text encoder is frozen
+    if model_args.freeze_text_encoder:
+        text_data_collator = T5TextCollatorStableSpeechWithPadding(description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of)
+        for split in vectorized_datasets:
+            data_loader = DataLoader(
+                vectorized_datasets[split],
+                batch_size=training_args.text_encode_per_device_eval_batch_size,
+                collate_fn=text_data_collator,
+                num_workers=training_args.dataloader_num_workers,
+                pin_memory=True,
+            )
+            data_loader = accelerator.prepare(data_loader)
+
+            all_encoder_outputs = []
+            all_encoder_lengths = []
+            for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
+                model.text_encoder.to(batch["input_ids"].device)
+                with accelerator.autocast(autocast_handler=autocast_kwargs):
+                    encoder_outputs = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
+                encoder_outputs = accelerator.pad_across_processes(encoder_outputs, dim=1, pad_index=prompt_tokenizer.pad_token_id)
+                encoder_outputs = accelerator.gather_for_metrics(encoder_outputs)
+                lengths = accelerator.gather_for_metrics(batch["len_input_ids"])
+                # TODO: check it works multi device
+                all_encoder_outputs.extend(encoder_outputs.last_hidden_state.to("cpu"))
+                all_encoder_lengths.extend(lengths.to("cpu"))
+
+            def postprocess_dataset(input_ids, idx):
+                output = {"encoder_outputs": BaseModelOutput(last_hidden_state=all_encoder_outputs[idx][: all_encoder_lengths[idx]])}
+                return output
+
+            # TODO: done multiple times, how to deal with it.
+            with accelerator.main_process_first():
+                vectorized_datasets[split] = vectorized_datasets[split].map(
+                    postprocess_dataset,
+                    num_proc=1,  # this one is resource consuming if many processors.
+                    input_columns=["input_ids"],
+                    desc="Postprocessing labeling",
+                    with_indices=True,
+                    writer_batch_size=100,
+                )

     if data_args.save_to_disk is not None and not dataset_was_precomputed:
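Because gather_for_metrics pads every example to a shared length, postprocess_dataset trims each hidden-state tensor back to its true token count before storing it. A toy version of that slice (the shapes below are illustrative):

import torch

# One gathered example: padded to 10 positions, but only 7 are real tokens.
hidden_state = torch.randn(10, 512)
true_length = 7
trimmed = hidden_state[:true_length]  # what ends up in "encoder_outputs"
print(trimmed.shape)                  # torch.Size([7, 512])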
@@ -1110,11 +1213,6 @@ def main():
     # 6. Next, we can prepare the training.

-    # enable gradient checkpointing if necessary
-    if training_args.gradient_checkpointing:
-        model.gradient_checkpointing_enable()
-
     # Let's use CLAP similarity and WER metrics as our evaluation metrics,
     # Define evaluation metrics during training, *i.e.* CLAP similarity TODO: allow using another CLAP
@@ -1126,14 +1224,15 @@ def main():
     def clap_similarity(texts, audios, device):
         clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
         clap.to(device)
-        text_features = clap.get_text_features(clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None))
-        audio_features = clap.get_audio_features(clap_inputs["input_features"])
-        cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)
+        with torch.no_grad():
+            text_features = clap.get_text_features(clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None))
+            audio_features = clap.get_audio_features(clap_inputs["input_features"])
+            cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)

         clap.to("cpu")
         clap_inputs.to("cpu")
-        return cosine_sim.mean()
+        return cosine_sim.mean().to("cpu")

     def wer(prompts, audios, device):
         asr_pipeline = pipeline(model="distil-whisper/distil-large-v2", device=device)
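Wrapping the CLAP forward passes in torch.no_grad() matters because the metric is never backpropagated; without it, activations are retained for a gradient that is never computed. A minimal demonstration:

import torch

x = torch.randn(4, 8, requires_grad=True)
with torch.no_grad():
    sim = torch.nn.functional.cosine_similarity(x, x, dim=1, eps=1e-8)
print(sim.requires_grad)  # False: no graph was built, so no extra memory is held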
@@ -1186,6 +1285,9 @@ def main():
     else:
         eval_steps = training_args.eval_steps

+    # T5 doesn't support fp16
+    autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
+
     # Define optimizer, LR scheduler, collator
     optimizer = torch.optim.AdamW(
         params=model.parameters(),
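AutocastKwargs is accelerate's handle for overriding the default mixed-precision context on a per-block basis. A minimal sketch of the pattern this diff relies on (standalone, not taken from the script):

from accelerate import Accelerator
from accelerate.utils import AutocastKwargs

accelerator = Accelerator(mixed_precision="fp16")
# Disable autocast inside this block only: T5-like encoders are numerically
# unstable in float16, so their forward pass should run in full precision.
autocast_kwargs = AutocastKwargs(enabled=False)
with accelerator.autocast(autocast_handler=autocast_kwargs):
    pass  # the text-encoder forward pass would go here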
@@ -1208,8 +1310,6 @@ def main():
         audio_feature_extractor=feature_extractor,
         feature_extractor_input_name=feature_extractor_input_name,
         prompt_tokenizer=prompt_tokenizer,
         description_tokenizer=description_tokenizer,
         pad_to_multiple_of=data_args.pad_to_multiple_of)

-    # Freeze Encoders
-    model.freeze_encoders(model_args.freeze_text_encoder)
-
     # Prepare everything with accelerate
     model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
@@ -1318,10 +1418,13 @@ def main():
     ):
         model.train()

-        if mixed_precision == "fp16":
+        if mixed_precision == "fp16" and not model_args.freeze_text_encoder:
             # fp16 doesn't work with T5-like models
             with accelerator.autocast(autocast_handler=autocast_kwargs):
-                encoder_outputs = model.module.text_encoder(input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
+                if training_args.parallel_mode.value != "distributed":
+                    encoder_outputs = model.text_encoder(input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
+                else:
+                    encoder_outputs = model.module.text_encoder(input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
             batch["encoder_outputs"] = encoder_outputs

         outputs = model(**batch)
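The distributed/non-distributed branch exists because DistributedDataParallel wraps the model, hiding attributes like text_encoder behind .module. An equivalent, more generic unwrap for context (a sketch; the diff itself checks training_args.parallel_mode instead):

# hasattr-based unwrapping; works whether or not DDP has wrapped the model.
unwrapped = model.module if hasattr(model, "module") else model
encoder_outputs = unwrapped.text_encoder(
    input_ids=batch.get("input_ids"),
    attention_mask=batch.get("attention_mask", None),
)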
@@ -1335,10 +1438,13 @@ def main():
     # Define eval fn
     def eval_step(batch, accelerator, autocast_kwargs,):
         model.eval()

-        if mixed_precision == "fp16":
+        if mixed_precision == "fp16" and not model_args.freeze_text_encoder:
             # fp16 doesn't work with T5-like models
             with accelerator.autocast(autocast_handler=autocast_kwargs):
-                encoder_outputs = model.module.text_encoder(input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
+                if training_args.parallel_mode.value != "distributed":
+                    encoder_outputs = model.text_encoder(input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
+                else:
+                    encoder_outputs = model.module.text_encoder(input_ids=batch.get("input_ids"), attention_mask=batch.get("attention_mask", None))
             batch["encoder_outputs"] = encoder_outputs

         with torch.no_grad():
@@ -1354,7 +1460,6 @@ def main():
         output_audios = accelerator.pad_across_processes(output_audios, dim=1, pad_index=0)
         return output_audios

-    autocast_kwargs = AutocastKwargs(enabled=False)
     for epoch in range(epochs_trained, num_epochs):
         vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
         # TODO: add args
@@ -1471,9 +1576,9 @@ def main():
                 # TODO: better gather
                 generated_audios, input_ids, prompts = accelerator.pad_across_processes((generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0)
                 generated_audios, input_ids, prompts = accelerator.gather_for_metrics((generated_audios, input_ids, prompts))

-                eval_preds.extend(generated_audios)
-                eval_descriptions.extend(input_ids)
-                eval_prompts.extend(prompts)
+                eval_preds.extend(generated_audios.to("cpu"))
+                eval_descriptions.extend(input_ids.to("cpu"))
+                eval_prompts.extend(prompts.to("cpu"))

         eval_time = time.time() - eval_start
         # normalize eval metrics
@@ -1489,16 +1594,17 @@ def main():
)
eval_metrics
.
update
(
metric_values
)
metrics_desc
=
" "
.
join
([
f
"Eval
{
key
}
:
{
value
}
|"
for
key
,
value
in
metric_values
.
items
()])
log_pred
(
accelerator
,
pred_descriptions
,
pred_prompts
,
transcriptions
,
audios
,
sampling_rate
=
sampling_rate
,
step
=
cur_step
,
prefix
=
"eval"
,
)
if
"wandb"
in
training_args
.
report_to
:
log_pred
(
accelerator
,
pred_descriptions
,
pred_prompts
,
transcriptions
,
audios
,
sampling_rate
=
sampling_rate
,
step
=
cur_step
,
prefix
=
"eval"
,
)
# Print metrics and update progress bar
steps_trained_progress_bar
.
write
(
...
...
stable_speech/modeling_stable_speech.py
@@ -2617,7 +2617,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         output_values = []
         for sample_id in range(batch_size):
             sample = output_ids[:, sample_id]
-            sample_mask = (((sample == generation_config.bos_token_id) | (sample == generation_config.eos_token_id) | (sample == generation_config.pad_token_id)).sum(dim=(0, 1)) == 0)
+            sample_mask = ((sample >= self.audio_encoder.config.codebook_size).sum(dim=(0, 1)) == 0)
             if sample_mask.sum() > 0:
                 sample = sample[:, :, sample_mask]
                 sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
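This is the "fix generation" part of the commit: instead of dropping frames that contain bos/eos/pad special tokens, a frame is kept only when every codebook entry is a valid audio code, i.e. strictly below the codec's codebook_size. A toy run of the new mask (shapes and values are illustrative):

import torch

codebook_size = 1024
# (1, num_codebooks, num_frames): the middle frame contains an invalid id.
sample = torch.tensor([[[3, 1024, 7],
                        [5,    2, 9]]])
sample_mask = (sample >= codebook_size).sum(dim=(0, 1)) == 0
print(sample_mask)                # tensor([ True, False,  True])
print(sample[:, :, sample_mask])  # the invalid frame is dropped before decoding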