chenpangpang / parler-tts

Commit eaf7947b
Authored Apr 24, 2024 by Dan Lyth
Parent: 3170ac02

small train.py updates
Showing 1 changed file with 16 additions and 48 deletions.

training/train.py (+16, -48)

--- a/training/train.py
+++ b/training/train.py
@@ -24,8 +24,6 @@ import time
 from multiprocess import set_start_method
 from datetime import timedelta
-import evaluate
 from tqdm import tqdm
 from pathlib import Path
@@ -33,22 +31,18 @@ import datasets
 import torch
 from torch.utils.data import DataLoader
-from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets
+from datasets import IterableDataset
 from huggingface_hub import Repository, create_repo
 import transformers
 from transformers import (
     AutoFeatureExtractor,
-    AutoModel,
-    AutoProcessor,
     AutoTokenizer,
-    HfArgumentParser
+    HfArgumentParser,
 )
 from transformers.trainer_pt_utils import LengthGroupedSampler
-from transformers import pipeline
 from transformers.optimization import get_scheduler
 from transformers.utils import send_example_telemetry
-from transformers import AutoModel
 from accelerate import Accelerator
@@ -57,13 +51,13 @@ from accelerate.utils.memory import release_memory
 from parler_tts import (
     ParlerTTSForConditionalGeneration,
-    ParlerTTSConfig,
-    build_delay_pattern_mask,
+    ParlerTTSConfig
 )
 from parler_tts.utils import get_last_checkpoint, rotate_checkpoints, log_pred, log_metric
 from parler_tts.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
-from parler_tts.data import DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
+from parler_tts.data import DataCollatorParlerTTSWithPadding
+from parler_tts.eval import clap_similarity, wer

 logger = logging.getLogger(__name__)
@@ -271,47 +265,22 @@ def main():
     # Let's use word CLAP similary and WER metrics as our evaluation metrics    # TODO move this to seperate file
     # Define evaluation metrics during training, *i.e.* CLAP similarity
-    clap = AutoModel.from_pretrained(model_args.clap_model_name_or_path)
-    clap_processor = AutoProcessor.from_pretrained(model_args.clap_model_name_or_path)
-    metric = evaluate.load("wer")
-
-    def clap_similarity(texts, audios, device):
-        clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
-        clap.to(device)
-        with torch.no_grad():
-            text_features = clap.get_text_features(clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None))
-            audio_features = clap.get_audio_features(clap_inputs["input_features"])
-        cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)
-        clap.to("cpu")
-        clap_inputs.to("cpu")
-        return cosine_sim.mean().to("cpu")
-
-    def wer(prompts, audios, device):
-        asr_pipeline = pipeline(model=model_args.asr_model_name_or_path, device=device)
-        transcriptions = asr_pipeline(
-            [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
-            batch_size=int(training_args.per_device_eval_batch_size),
-        )
-        word_error = 100 * metric.compute(predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts])
-        return word_error, [t["text"] for t in transcriptions]
-
-    eval_methods = {"clap": clap_similarity, "wer": wer}
-
     def compute_metrics(audios, descriptions, prompts, device="cpu"):
+        results = {}
         input_ids = descriptions
         texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
         prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
         audios = [a.cpu().numpy() for a in audios]
-        results = {"clap": eval_methods["clap"](texts, audios, device)}
-        word_error, transcriptions = eval_methods["wer"](prompts, audios, device)
+        clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
+        results["clap"] = clap_score
+        word_error, transcriptions = wer(model_args.asr_model_name_or_path, prompts, audios, device, training_args.per_device_eval_batch_size, sampling_rate)
         results["wer"] = word_error
         return results, texts, prompts, audios, transcriptions
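Note on the hunk above: compute_metrics now delegates to clap_similarity and wer imported from parler_tts.eval, which is what the removed "# TODO move this to seperate file" asked for. The parler_tts/eval.py module itself is not part of this diff, so the following is only a hypothetical reconstruction inferred from the removed inline code and the new call sites: the signatures are read off the calls in compute_metrics, and the bodies fold the old module-level clap, clap_processor, and metric objects into the functions. The actual committed module may differ.

import evaluate
import torch
from transformers import AutoModel, AutoProcessor, pipeline


def clap_similarity(clap_model_name_or_path, texts, audios, device):
    # Load CLAP here instead of keeping it as module-level state in train.py.
    clap = AutoModel.from_pretrained(clap_model_name_or_path)
    clap_processor = AutoProcessor.from_pretrained(clap_model_name_or_path)
    clap_inputs = clap_processor(text=texts, audios=audios, padding=True, return_tensors="pt").to(device)
    clap.to(device)
    with torch.no_grad():
        text_features = clap.get_text_features(
            clap_inputs["input_ids"], attention_mask=clap_inputs.get("attention_mask", None)
        )
        audio_features = clap.get_audio_features(clap_inputs["input_features"])
    # Batch-averaged cosine similarity between the audio and text embeddings.
    cosine_sim = torch.nn.functional.cosine_similarity(audio_features, text_features, dim=1, eps=1e-8)
    clap.to("cpu")
    clap_inputs.to("cpu")
    return cosine_sim.mean().to("cpu")


def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
    # Transcribe the generated audio with an ASR pipeline and score it against the prompts.
    metric = evaluate.load("wer")
    asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
    transcriptions = asr_pipeline(
        [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
        batch_size=int(per_device_eval_batch_size),
    )
    word_error = 100 * metric.compute(
        predictions=[t["text"].lower() for t in transcriptions],
        references=[t.lower() for t in prompts],
    )
    # Also return the raw transcriptions so the trainer can log them next to the score.
    return word_error, [t["text"] for t in transcriptions]

One practical consequence of the new signatures: sampling_rate and the eval batch size are now passed in explicitly instead of being captured from main()'s scope, so the helpers no longer depend on train.py globals.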
@@ -564,7 +533,6 @@ def main():
             resume_step = None

         for batch in train_dataloader:
-            breakpoint()
             with accelerator.accumulate(model):
                 loss, train_metric = train_step(batch, accelerator, autocast_kwargs)
                 accelerator.backward(loss)
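The final hunk deletes a stray breakpoint() that shipped inside the training loop. For anyone tempted to reintroduce one: under a multi-process accelerate launch, a bare breakpoint() drops every rank into pdb at once. A guarded variant along the lines below (hypothetical, not from this commit; accelerator is the object already in scope in train.py, and DEBUG_TRAIN_LOOP is an invented opt-in variable) limits the stop to the main process:

import os

# Hypothetical debug guard, not part of this commit: opt in via an environment
# variable and only let the main process enter the debugger.
if os.environ.get("DEBUG_TRAIN_LOOP") and accelerator.is_main_process:
    breakpoint()

Even then, the other ranks will stall at the next collective call while rank 0 sits in the debugger, so deleting the call outright, as this commit does, is the cleaner fix.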