chenpangpang / parler-tts · Commits

Commit 7ae1e8e3, authored Feb 27, 2024 by Yoach Lacombe

    remove trainer code

Parent: 9d25447e
Showing 1 changed file with 0 additions and 135 deletions:

    run_stable_speech_training.py  +0 −135
run_stable_speech_training.py @ 7ae1e8e3

```diff
...
@@ -1455,141 +1455,6 @@ def main():
     accelerator.end_training()
 
-    ###########################################################################
-    # Initialize StableSpeechTrainer
-    trainer = StableSpeechTrainer(
-        model=model,
-        data_collator=data_collator,
-        args=training_args,
-        compute_metrics=compute_metrics,
-        train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
-        eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
-        tokenizer=feature_extractor,
-    )
-
-    if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to and training_args.do_eval:
-        max_eval_samples = (
-            data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
-        )
-
-        def decode_predictions(predictions):
-            audios = predictions.predictions
-            return {"audio": np.array(audios)}
-
-        class WandbPredictionProgressCallback(WandbCallback):
-            """Custom WandbCallback to log model predictions during training."""
-
-            def __init__(
-                self,
-                trainer,
-                val_dataset,
-                description_tokenizer,
-                # TODO: add
-                num_samples=8,
-            ):
-                """Initializes the WandbPredictionProgressCallback instance.
-
-                Args:
-                    trainer (Seq2SeqTrainer): The Hugging Face Seq2SeqTrainer instance.
-                    val_dataset (Dataset): The validation dataset.
-                    num_samples (int, optional): Number of samples to select from
-                        the validation dataset for generating predictions.
-                        Defaults to 8.
-                """
-                super().__init__()
-                self.trainer = trainer
-                self.description_tokenizer = description_tokenizer
-                self.sample_dataset = val_dataset.select(range(num_samples))
-
-            def on_evaluate(self, args, state, control, **kwargs):
-                super().on_evaluate(args, state, control, **kwargs)
-                predictions = self.trainer.predict(self.sample_dataset)
-                # decode predictions and labels
-                predictions = decode_predictions(predictions)
-                input_ids = self.sample_dataset["input_ids"]
-                texts = self.description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
-                audios = predictions["audio"]
-
-                # log the table to wandb
-                self._wandb.log(
-                    {
-                        "sample_songs": [
-                            self._wandb.Audio(audio, caption=text, sample_rate=sampling_rate)
-                            for (audio, text) in zip(audios, texts)
-                        ]
-                    }
-                )
-
-        # Instantiate the WandbPredictionProgressCallback
-        progress_callback = WandbPredictionProgressCallback(
-            trainer=trainer,
-            val_dataset=vectorized_datasets["eval"],
-            description_tokenizer=description_tokenizer,
-            num_samples=max_eval_samples,
-        )
-
-        # Add the callback to the trainer
-        trainer.add_callback(progress_callback)
-
-    # 8. Finally, we can start training
-    # Training
-    if training_args.do_train:
-        # use the last checkpoint if one exists
-        if last_checkpoint is not None:
-            checkpoint = last_checkpoint
-        # TODO: loading the trainer from model_name_or_path doesn't work if the config was saved
-        # elif os.path.isdir(model_args.model_name_or_path):
-        #     checkpoint = model_args.model_name_or_path
-        else:
-            checkpoint = None
-
-        train_result = trainer.train(
-            resume_from_checkpoint=checkpoint, ignore_keys_for_eval=["past_key_values", "attentions"]
-        )
-        trainer.save_model()
-
-        metrics = train_result.metrics
-        max_train_samples = (
-            data_args.max_train_samples
-            if data_args.max_train_samples is not None
-            else len(vectorized_datasets["train"])
-        )
-        metrics["train_samples"] = min(max_train_samples, len(vectorized_datasets["train"]))
-        trainer.log_metrics("train", metrics)
-        trainer.save_metrics("train", metrics)
-        trainer.save_state()
-
-    # Evaluation
-    results = {}
-    if training_args.do_eval:
-        logger.info("*** Evaluate ***")
-        metrics = trainer.evaluate()
-        max_eval_samples = (
-            data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
-        )
-        metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
-
-        trainer.log_metrics("eval", metrics)
-        trainer.save_metrics("eval", metrics)
-
-    # Write model card and (optionally) push to hub
-    config_name = data_args.train_dataset_config_name if data_args.train_dataset_config_name is not None else "na"
-    kwargs = {
-        "finetuned_from": model_args.model_name_or_path,
-        "tasks": "text-to-speech",
-        "tags": ["text-to-speech", data_args.train_dataset_name],
-        "dataset_args": (
-            f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split:"
-            f" {data_args.eval_split_name}"
-        ),
-        "dataset": f"{data_args.train_dataset_name.upper()} - {config_name.upper()}",
-    }
-    if training_args.push_to_hub:
-        trainer.push_to_hub(**kwargs)
-    else:
-        trainer.create_model_card(**kwargs)
-
-    return results
-
 
 if __name__ == "__main__":
...
```
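For context on what replaces the deleted block: the retained context line `accelerator.end_training()` suggests the script now drives training with 🤗 Accelerate directly rather than through a `Trainer` subclass. Below is a minimal, self-contained sketch of that pattern; the linear model, random data, and hyperparameters are placeholders for illustration, not names taken from this repository.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

# Placeholder model and data; the real script trains a text-to-speech model.
model = torch.nn.Linear(16, 1)
dataset = TensorDataset(torch.randn(64, 16), torch.randn(64, 1))
dataloader = DataLoader(dataset, batch_size=8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# prepare() wraps model, optimizer, and dataloader for the current
# device / distributed setup.
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

model.train()
for inputs, targets in dataloader:
    loss = torch.nn.functional.mse_loss(model(inputs), targets)
    accelerator.backward(loss)  # replaces loss.backward(); handles grad scaling
    optimizer.step()
    optimizer.zero_grad()

accelerator.end_training()  # the same call that survives in the diff context above
```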
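The deleted `WandbPredictionProgressCallback` existed only to push generated audio to Weights & Biases through the Trainer's callback hooks; without a Trainer, the same logging can be done with a direct `wandb.log` call. A minimal sketch, assuming synthetic audio, illustrative captions, and a placeholder sampling rate; only the `sample_songs` key and the `wandb.Audio(..., caption=..., sample_rate=...)` call mirror the removed code:

```python
import numpy as np
import wandb

# Placeholders: in the real script the audio comes from model generation,
# the captions from decoding description token ids, and the sampling rate
# from the feature extractor.
sampling_rate = 16000
audios = [np.random.randn(sampling_rate).astype(np.float32) for _ in range(2)]
texts = ["a calm female voice", "a fast, energetic male voice"]

wandb.init(project="tts-demo", mode="offline")  # offline mode: no account needed
wandb.log(
    {
        "sample_songs": [
            wandb.Audio(audio, caption=text, sample_rate=sampling_rate)
            for audio, text in zip(audios, texts)
        ]
    }
)
wandb.finish()
```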