chenpangpang / parler-tts · Commits
"vscode:/vscode.git/clone" did not exist on "7cf5d8f77857e1cc64e585f46e2f656ea4eef8ec"
Commit aa4cbf27, authored May 14, 2024 by Yoach Lacombe

make style

parent 9271958b
Showing 7 changed files with 36 additions and 29 deletions.
parler_tts/configuration_parler_tts.py   +1  -1
parler_tts/modeling_parler_tts.py        +6  -3
training/arguments.py                    +7  -3
training/data.py                         +2  -1
training/eval.py                         +3  -2
training/run_parler_tts_training.py      +15 -18
training/utils.py                        +2  -1
parler_tts/configuration_parler_tts.py

@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
    Args:
        vocab_size (`int`, *optional*, defaults to 2049):
            Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
        hidden_size (`int`, *optional*, defaults to 1024):
            Dimensionality of the layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 24):
...
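For reference, a minimal sketch (not part of this commit) of instantiating the decoder config with the defaults documented above; the import path is an assumption about how the package exposes ParlerTTSDecoderConfig.

# Minimal sketch, not from this commit: the documented defaults above.
from parler_tts import ParlerTTSDecoderConfig  # import path assumed

decoder_config = ParlerTTSDecoderConfig(
    vocab_size=2049,       # codebook vocabulary size (documented default)
    hidden_size=1024,      # layer dimensionality (documented default)
    num_hidden_layers=24,  # number of decoder layers (documented default)
)
print(decoder_config.hidden_size)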
parler_tts/modeling_parler_tts.py

@@ -1522,7 +1522,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
            output_ids = outputs.sequences
        else:
            output_ids = outputs

        # apply the pattern mask to the final ids
        output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
...

@@ -2460,7 +2460,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
        if "encoder_outputs" not in model_kwargs:
            # encoder_outputs are created and added to `model_kwargs`
            model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
                inputs_tensor,
                model_kwargs,
                model_input_name,
                generation_config,
            )

        if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
...

@@ -2667,4 +2670,4 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
            outputs.sequences = output_values
            return outputs
        else:
            return output_values
\ No newline at end of file
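As context for the prompt_input_ids and output_values handling above, a hedged end-to-end usage sketch (not part of this commit; the checkpoint name and the sampling_rate attribute are assumptions based on the project README):

# Hedged usage sketch, not from this commit: generating speech with
# ParlerTTSForConditionalGeneration. The checkpoint name is an assumption.
import soundfile as sf
import torch
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1").to(device)
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

description = "A female speaker delivers her words slowly in a very quiet room."
prompt = "Hey, how are you doing today?"

input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)    # description -> text encoder
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)  # transcript -> prompt_input_ids

generation = model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)  # audio values out
audio = generation.cpu().numpy().squeeze()
sf.write("parler_tts_out.wav", audio, model.config.sampling_rate)  # sampling_rate attribute assumed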
training/arguments.py

@@ -3,6 +3,7 @@ from typing import Optional
from transformers import Seq2SeqTrainingArguments


@dataclass
class ModelArguments:
    """
...

@@ -67,15 +68,18 @@ class ModelArguments:
    )
    asr_model_name_or_path: str = field(
        default="distil-whisper/distil-large-v2",
-        metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
    )
    clap_model_name_or_path: str = field(
        default="laion/larger_clap_music_and_speech",
-        metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
    )


@dataclass
class DataTrainingArguments:
    """
...
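A minimal sketch (assumption, not part of this commit) of how these dataclass fields are typically consumed; the abbreviated ModelArguments below is illustrative only.

# Minimal sketch, not from this commit: parsing the fields above with
# HfArgumentParser (which the training script imports). Abbreviated class.
from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class ModelArguments:
    asr_model_name_or_path: str = field(
        default="distil-whisper/distil-large-v2",
        metadata={"help": "Used to compute WER during evaluation."},
    )
    clap_model_name_or_path: str = field(
        default="laion/larger_clap_music_and_speech",
        metadata={"help": "Used to compute audio similarity during evaluation."},
    )

parser = HfArgumentParser(ModelArguments)
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--asr_model_name_or_path", "openai/whisper-small"]
)
print(model_args.asr_model_name_or_path)   # openai/whisper-small
print(model_args.clap_model_name_or_path)  # laion/larger_clap_music_and_speech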
training/data.py

@@ -11,6 +11,7 @@ from tqdm import tqdm
from accelerate import Accelerator


@dataclass
class DataCollatorEncodecWithPadding:
    """
...

@@ -301,4 +302,4 @@ def load_multiple_datasets(
    with accelerator.main_process_first():
        interleaved_dataset = concatenate_datasets(all_datasets)

    return interleaved_dataset
\ No newline at end of file
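A minimal sketch (assumption, not part of this commit) of the main_process_first() pattern used in load_multiple_datasets, with toy in-memory datasets standing in for the real ones.

# Minimal sketch, not from this commit: rank 0 enters the block first and can
# build or cache the dataset; the other ranks wait, then reuse the cached result.
from accelerate import Accelerator
from datasets import Dataset, concatenate_datasets

accelerator = Accelerator()
all_datasets = [
    Dataset.from_dict({"text": ["a", "b"]}),
    Dataset.from_dict({"text": ["c"]}),
]

with accelerator.main_process_first():
    interleaved_dataset = concatenate_datasets(all_datasets)

print(len(interleaved_dataset))  # 3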
training/eval.py

import torch
import evaluate
from transformers import AutoModel, AutoProcessor, pipeline
...

@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
    clap_inputs.to("cpu")
    return cosine_sim.mean().to("cpu")


def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
    metric = evaluate.load("wer")
    asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
...

@@ -32,4 +33,4 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
        predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
    )

    return word_error, [t["text"] for t in transcriptions]
\ No newline at end of file
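A hedged usage sketch (not part of this commit) of the wer() helper defined above; it assumes the training package is importable from the repository root, and the silent audio and prompt are placeholders.

# Hedged sketch, not from this commit: calling wer() with placeholder inputs.
import numpy as np
from training.eval import wer  # import path assumed

generated_audios = [np.zeros(16000, dtype=np.float32)]  # 1 second of silence as a stand-in
reference_prompts = ["Hey, how are you doing today?"]

word_error, transcriptions = wer(
    "distil-whisper/distil-large-v2",  # asr_model_name_or_path (ModelArguments default)
    reference_prompts,                 # prompts
    generated_audios,                  # audios
    "cpu",                             # device
    1,                                 # per_device_eval_batch_size
    16000,                             # sampling_rate
)
print(word_error, transcriptions)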
training/run_parler_tts_training.py

@@ -21,7 +21,6 @@ import os
import re
import sys
import time
-from dataclasses import dataclass, field
from datetime import timedelta
from tqdm import tqdm
...

@@ -38,11 +37,7 @@ from huggingface_hub import HfApi
from multiprocess import set_start_method
from torch.utils.data import DataLoader
from tqdm import tqdm
-from transformers import (
-    AutoFeatureExtractor,
-    AutoTokenizer,
-    HfArgumentParser
-)
+from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers.optimization import get_scheduler
from transformers.trainer_pt_utils import LengthGroupedSampler
...

@@ -306,9 +301,7 @@ def main():
        # update pad token id and decoder_start_token_id
        config.update(
            {
                "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
                "decoder_start_token_id": model_args.decoder_start_token_id
                if model_args.decoder_start_token_id is not None
                else config.decoder_start_token_id,
...
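A minimal sketch (assumption, not part of this commit) of the config.update behaviour relied on above: PretrainedConfig.update overwrites attributes from a plain dict.

# Minimal sketch, not from this commit: PretrainedConfig.update() sets attributes
# from a dict, which is how pad_token_id / decoder_start_token_id are overridden.
from transformers import PretrainedConfig

config = PretrainedConfig(pad_token_id=0, decoder_start_token_id=None)
config.update(
    {
        "pad_token_id": 1024,
        "decoder_start_token_id": 1025,
    }
)
print(config.pad_token_id, config.decoder_start_token_id)  # 1024 1025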
@@ -579,16 +572,18 @@ def main():
        texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
        prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
        audios = [a.cpu().numpy() for a in audios]
        clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
        results["clap"] = clap_score
-        word_error, transcriptions = wer(model_args.asr_model_name_or_path,
-            prompts,
-            audios,
-            device,
-            training_args.per_device_eval_batch_size,
-            sampling_rate)
+        word_error, transcriptions = wer(
+            model_args.asr_model_name_or_path,
+            prompts,
+            audios,
+            device,
+            training_args.per_device_eval_batch_size,
+            sampling_rate,
+        )
        results["wer"] = word_error

        return results, texts, prompts, audios, transcriptions
...

@@ -878,7 +873,9 @@ def main():
                accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger)

                if cur_step == total_train_steps:
                    # un-wrap student model for save
...

@@ -1020,4 +1017,4 @@ def main():
if __name__ == "__main__":
    set_start_method("spawn")
    main()
\ No newline at end of file
training/utils.py

@@ -8,6 +8,7 @@ from typing import Dict, List
import torch
from wandb import Audio


def list_field(default=None, metadata=None):
    return field(default_factory=lambda: default, metadata=metadata)
...

@@ -121,4 +122,4 @@ def log_pred(
            ]
        },
        step=step,
    )
\ No newline at end of file
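A minimal sketch (not part of this commit) of why list_field exists: dataclasses reject mutable defaults, so the helper wraps the value in a default_factory; the example dataclass below is hypothetical.

# Minimal sketch, not from this commit: a list default via list_field.
from dataclasses import dataclass, field

def list_field(default=None, metadata=None):
    return field(default_factory=lambda: default, metadata=metadata)

@dataclass
class ExampleArguments:  # hypothetical dataclass for illustration
    eval_metrics: list = list_field(default=["wer", "clap"])

print(ExampleArguments().eval_metrics)  # ['wer', 'clap']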