chenpangpang / parler-tts · Commits

Commit aa4cbf27
authored May 14, 2024 by Yoach Lacombe
parent 9271958b

    make style

Showing 7 changed files with 36 additions and 29 deletions (+36 / -29)
parler_tts/configuration_parler_tts.py    +1  -1
parler_tts/modeling_parler_tts.py         +6  -3
training/arguments.py                     +7  -3
training/data.py                          +2  -1
training/eval.py                          +3  -2
training/run_parler_tts_training.py       +15 -18
training/utils.py                         +2  -1
parler_tts/configuration_parler_tts.py

@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
-            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
+            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
             Dimensionality of the layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 24):
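For reference, the documented defaults above can be overridden when building a decoder config. A minimal sketch, assuming the `parler_tts` package is installed and exposes the class at the module path shown in this file:

    # Hedged sketch: override the documented defaults of the decoder config.
    from parler_tts.configuration_parler_tts import ParlerTTSDecoderConfig

    decoder_config = ParlerTTSDecoderConfig(
        vocab_size=2049,       # number of distinct audio tokens (docstring default)
        hidden_size=1024,      # dimensionality of the layers and the pooler layer
        num_hidden_layers=24,  # decoder depth
    )
    print(decoder_config.vocab_size)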
parler_tts/modeling_parler_tts.py

@@ -1522,7 +1522,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             output_ids = outputs.sequences
         else:
             output_ids = outputs

         # apply the pattern mask to the final ids
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])

@@ -2460,7 +2460,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if "encoder_outputs" not in model_kwargs:
             # encoder_outputs are created and added to `model_kwargs`
-            model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
-                inputs_tensor, model_kwargs, model_input_name, generation_config,
-            )
+            model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
+                inputs_tensor,
+                model_kwargs,
+                model_input_name,
+                generation_config,
+            )

         if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:

@@ -2667,4 +2670,4 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             outputs.sequences = output_values
             return outputs
         else:
-            return output_values
\ No newline at end of file
+            return output_values
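The first hunk applies `apply_delay_pattern_mask` to the generated ids. As a rough illustration of the MusicGen-style operation (a hedged sketch, not this repo's implementation; the `-1` sentinel and tensor shapes are assumptions):

    import torch

    # Hedged sketch of a delay-pattern mask application: positions where the mask
    # holds the sentinel keep the generated token, all other positions are forced
    # to the value stored in the mask (e.g. pad/BOS ids).
    def apply_delay_pattern_mask_sketch(output_ids, delay_pattern_mask, sentinel=-1):
        delay_pattern_mask = delay_pattern_mask[..., : output_ids.shape[-1]]
        return torch.where(delay_pattern_mask == sentinel, output_ids, delay_pattern_mask)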
training/arguments.py

@@ -3,6 +3,7 @@ from typing import Optional
 from transformers import Seq2SeqTrainingArguments


 @dataclass
 class ModelArguments:
     """

@@ -67,15 +68,18 @@ class ModelArguments:
     )
     asr_model_name_or_path: str = field(
         default="distil-whisper/distil-large-v2",
-        metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )
     clap_model_name_or_path: str = field(
         default="laion/larger_clap_music_and_speech",
-        metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )


 @dataclass
 class DataTrainingArguments:
     """
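These `field(default=..., metadata={"help": ...})` declarations become CLI flags once the dataclass is handed to `HfArgumentParser` (imported in the training script below). A minimal sketch using a hypothetical trimmed-down stand-in for `ModelArguments`, keeping only the two fields visible in this hunk:

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    @dataclass
    class EvalModelArguments:  # hypothetical stand-in; the real ModelArguments defines many more fields
        asr_model_name_or_path: str = field(
            default="distil-whisper/distil-large-v2",
            metadata={"help": "Used to compute WER during evaluation."},
        )
        clap_model_name_or_path: str = field(
            default="laion/larger_clap_music_and_speech",
            metadata={"help": "Used to compute audio similarity during evaluation."},
        )

    # Field names become flags; the metadata "help" strings become --help text.
    (args,) = HfArgumentParser(EvalModelArguments).parse_args_into_dataclasses(
        ["--asr_model_name_or_path", "openai/whisper-small"]
    )
    print(args.asr_model_name_or_path, args.clap_model_name_or_path)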
training/data.py

@@ -11,6 +11,7 @@ from tqdm import tqdm
 from accelerate import Accelerator


 @dataclass
 class DataCollatorEncodecWithPadding:
     """

@@ -301,4 +302,4 @@ def load_multiple_datasets(
     with accelerator.main_process_first():
         interleaved_dataset = concatenate_datasets(all_datasets)

-    return interleaved_dataset
\ No newline at end of file
+    return interleaved_dataset
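The tail of `load_multiple_datasets` concatenates the loaded datasets under `accelerator.main_process_first()` so the result is built and cached once before other ranks reuse it. A self-contained sketch of that pattern with placeholder toy datasets:

    from accelerate import Accelerator
    from datasets import Dataset, concatenate_datasets

    accelerator = Accelerator()
    all_datasets = [  # placeholders standing in for the loaded training corpora
        Dataset.from_dict({"text": ["hello"], "idx": [0]}),
        Dataset.from_dict({"text": ["world"], "idx": [1]}),
    ]

    # Rank 0 enters first and writes the cache; other ranks wait, then reuse it.
    with accelerator.main_process_first():
        interleaved_dataset = concatenate_datasets(all_datasets)

    print(len(interleaved_dataset))  # 2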
training/eval.py

-import torch
+import torch
 import evaluate
 from transformers import AutoModel, AutoProcessor, pipeline

@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
         clap_inputs.to("cpu")
     return cosine_sim.mean().to("cpu")


 def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
     metric = evaluate.load("wer")
     asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)

@@ -32,4 +33,4 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
         predictions=[t["text"].lower() for t in transcriptions],
         references=[t.lower() for t in prompts]
     )

-    return word_error, [t["text"] for t in transcriptions]
\ No newline at end of file
+    return word_error, [t["text"] for t in transcriptions]
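Putting the pieces of `wer()` together: generated audio is transcribed with a `transformers` ASR pipeline and scored against the prompts with the `evaluate` WER metric. A hedged, self-contained sketch of that flow (the raw-array input format is an assumption about how the audio is fed to the pipeline):

    import evaluate
    from transformers import pipeline

    def wer_sketch(asr_model_name_or_path, prompts, audios, device, batch_size, sampling_rate):
        metric = evaluate.load("wer")
        asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)

        # Raw numpy arrays can be passed to the ASR pipeline together with their sampling rate.
        transcriptions = asr_pipeline(
            [{"raw": audio, "sampling_rate": sampling_rate} for audio in audios],
            batch_size=batch_size,
        )

        word_error = metric.compute(
            predictions=[t["text"].lower() for t in transcriptions],
            references=[t.lower() for t in prompts],
        )
        return word_error, [t["text"] for t in transcriptions]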
training/run_parler_tts_training.py

@@ -21,7 +21,6 @@ import os
 import re
 import sys
 import time
 from dataclasses import dataclass, field
 from datetime import timedelta
 from tqdm import tqdm

@@ -38,11 +37,7 @@ from huggingface_hub import HfApi
 from multiprocess import set_start_method
 from torch.utils.data import DataLoader
 from tqdm import tqdm
-from transformers import (
-    AutoFeatureExtractor,
-    AutoTokenizer,
-    HfArgumentParser,
-)
+from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
-from transformers.trainer_pt_utils import LengthGroupedSampler
 from transformers.optimization import get_scheduler
+from transformers.trainer_pt_utils import LengthGroupedSampler

@@ -306,9 +301,7 @@ def main():
     # update pad token id and decoder_start_token_id
     config.update(
         {
-            "pad_token_id": model_args.pad_token_id
-            if model_args.pad_token_id is not None
-            else config.pad_token_id,
+            "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
             "decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else config.decoder_start_token_id,

@@ -579,16 +572,18 @@ def main():
         texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
         prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
         audios = [a.cpu().numpy() for a in audios]
         clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
         results["clap"] = clap_score

-        word_error, transcriptions = wer(model_args.asr_model_name_or_path, prompts, audios, device, training_args.per_device_eval_batch_size, sampling_rate)
+        word_error, transcriptions = wer(
+            model_args.asr_model_name_or_path,
+            prompts,
+            audios,
+            device,
+            training_args.per_device_eval_batch_size,
+            sampling_rate,
+        )
         results["wer"] = word_error

         return results, texts, prompts, audios, transcriptions

@@ -878,7 +873,9 @@ def main():
                     accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
-                        rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger)
+                        rotate_checkpoints(
+                            training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
+                        )

                     if cur_step == total_train_steps:
                         # un-wrap student model for save

@@ -1020,4 +1017,4 @@ def main():
 if __name__ == "__main__":
     set_start_method("spawn")
-    main()
\ No newline at end of file
+    main()
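`rotate_checkpoints` is called on the main process after each intermediate save; its implementation is not part of this diff. As a rough illustration, a helper with this signature typically prunes all but the newest `save_total_limit` checkpoint directories (a hedged sketch, not the repo's code; the `checkpoint-<step>` naming is an assumption):

    import os
    import re
    import shutil

    def rotate_checkpoints_sketch(save_total_limit, output_dir, logger=None):
        if save_total_limit is None or save_total_limit <= 0:
            return
        # Collect checkpoint-<step> directories and sort them by step, oldest first.
        pattern = re.compile(r"checkpoint-(\d+)")
        checkpoints = sorted(
            (int(m.group(1)), name)
            for name in os.listdir(output_dir)
            if (m := pattern.fullmatch(name)) and os.path.isdir(os.path.join(output_dir, name))
        )
        # Delete everything beyond the newest `save_total_limit` checkpoints.
        for _, name in checkpoints[: max(len(checkpoints) - save_total_limit, 0)]:
            if logger is not None:
                logger.info(f"Deleting older checkpoint [{name}] due to save_total_limit")
            shutil.rmtree(os.path.join(output_dir, name), ignore_errors=True)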
training/utils.py

@@ -8,6 +8,7 @@ from typing import Dict, List
 import torch
 from wandb import Audio


 def list_field(default=None, metadata=None):
     return field(default_factory=lambda: default, metadata=metadata)

@@ -121,4 +122,4 @@ def log_pred(
             ]
         },
         step=step,
-    )
\ No newline at end of file
+    )
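`list_field` wraps `dataclasses.field` in a `default_factory` so a list-valued argument can be given a default without tripping dataclasses' mutable-default check. A small usage sketch with a hypothetical dataclass and flag name:

    from dataclasses import dataclass, field

    def list_field(default=None, metadata=None):
        return field(default_factory=lambda: default, metadata=metadata)

    @dataclass
    class ExampleArguments:  # hypothetical, only to illustrate list_field
        # field(default=[0.5, 0.9]) would raise "mutable default" ValueError;
        # list_field is accepted because the default comes from a factory.
        milestones: list = list_field(default=[0.5, 0.9], metadata={"help": "hypothetical flag"})

    print(ExampleArguments().milestones)  # [0.5, 0.9]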