chenpangpang / parler-tts · Commit bdb03638 (unverified)

Authored May 14, 2024 by Yoach Lacombe; committed by GitHub on May 14, 2024.

Merge pull request #48 from ylacombe/pr/Wauplin/18

Pr/wauplin/18

Parents: b2b749d1, 3f5fd26c

Showing 7 changed files with 48 additions and 37 deletions (+48 -37):
parler_tts/configuration_parler_tts.py   +1  -1
parler_tts/modeling_parler_tts.py        +6  -3
training/arguments.py                    +7  -3
training/data.py                         +2  -1
training/eval.py                         +3  -2
training/run_parler_tts_training.py      +27 -26
training/utils.py                        +2  -1
parler_tts/configuration_parler_tts.py

@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
-            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
+            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
             Dimensionality of the layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 24):
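For orientation, the class documented above can be instantiated with these defaults directly. A minimal sketch, assuming the class is importable from the module path implied by the file name (parler_tts.configuration_parler_tts):

    # Illustrative only; the import path is assumed from the file layout shown above.
    from parler_tts.configuration_parler_tts import ParlerTTSDecoderConfig

    decoder_config = ParlerTTSDecoderConfig()  # relies on the documented defaults
    print(decoder_config.vocab_size, decoder_config.hidden_size)  # expected 2049 1024, per the docstring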
parler_tts/modeling_parler_tts.py

@@ -1522,7 +1522,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             output_ids = outputs.sequences
         else:
             output_ids = outputs

         # apply the pattern mask to the final ids
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])

@@ -2460,7 +2460,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if "encoder_outputs" not in model_kwargs:
             # encoder_outputs are created and added to `model_kwargs`
             model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
-                inputs_tensor, model_kwargs, model_input_name, generation_config,
+                inputs_tensor,
+                model_kwargs,
+                model_input_name,
+                generation_config,
             )

         if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:

@@ -2667,4 +2670,4 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             outputs.sequences = output_values
             return outputs
         else:
-            return output_values
\ No newline at end of file
+            return output_values
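The generation path above finishes by re-applying the delay pattern mask to the sampled ids. As a rough, MusicGen-style illustration of what such a mask application usually amounts to (not the repository's exact code), positions where the mask is -1 keep the generated token and every other position is overwritten by the mask value:

    import torch

    def apply_delay_pattern_mask_sketch(input_ids, delay_pattern_mask):
        # Keep generated tokens where the mask is -1; force the mask value elsewhere.
        return torch.where(delay_pattern_mask == -1, input_ids, delay_pattern_mask)

    ids = torch.tensor([[5, 6, 7, 8]])
    mask = torch.tensor([[-1, -1, 0, 0]])  # last two positions forced to 0
    print(apply_delay_pattern_mask_sketch(ids, mask))  # tensor([[5, 6, 0, 0]])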
training/arguments.py

@@ -3,6 +3,7 @@ from typing import Optional
 from transformers import Seq2SeqTrainingArguments


 @dataclass
 class ModelArguments:
     """

@@ -67,15 +68,18 @@ class ModelArguments:
     )
     asr_model_name_or_path: str = field(
         default="distil-whisper/distil-large-v2",
-        metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )
     clap_model_name_or_path: str = field(
         default="laion/larger_clap_music_and_speech",
-        metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={
+            "help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"
+        },
     )


 @dataclass
 class DataTrainingArguments:
     """
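The fields above are ordinary dataclass arguments, so HfArgumentParser can fill them from the command line. A self-contained sketch with a hypothetical dataclass that mirrors the two fields touched by this commit:

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    @dataclass
    class EvalModelArguments:  # hypothetical stand-in for ModelArguments
        asr_model_name_or_path: str = field(
            default="distil-whisper/distil-large-v2",
            metadata={"help": "Used to compute WER during evaluation."},
        )
        clap_model_name_or_path: str = field(
            default="laion/larger_clap_music_and_speech",
            metadata={"help": "Used to compute audio similarity during evaluation."},
        )

    parser = HfArgumentParser(EvalModelArguments)
    (args,) = parser.parse_args_into_dataclasses(["--asr_model_name_or_path", "openai/whisper-small"])
    print(args.asr_model_name_or_path, args.clap_model_name_or_path)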
training/data.py

@@ -11,6 +11,7 @@ from tqdm import tqdm
 from accelerate import Accelerator


 @dataclass
 class DataCollatorEncodecWithPadding:
     """

@@ -301,4 +302,4 @@ def load_multiple_datasets(
     with accelerator.main_process_first():
         interleaved_dataset = concatenate_datasets(all_datasets)

-    return interleaved_dataset
\ No newline at end of file
+    return interleaved_dataset
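load_multiple_datasets ends by concatenating the per-dataset splits under accelerator.main_process_first(). The concatenation step on its own, with stand-in data, looks like this:

    from datasets import Dataset, concatenate_datasets

    # Stand-in datasets; in the training script these come from load_dataset and preprocessing.
    ds_a = Dataset.from_dict({"text": ["a", "b"]})
    ds_b = Dataset.from_dict({"text": ["c"]})

    combined = concatenate_datasets([ds_a, ds_b])
    print(len(combined))  # 3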
training/eval.py

 import torch
 import evaluate
 from transformers import AutoModel, AutoProcessor, pipeline

@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
         clap_inputs.to("cpu")
     return cosine_sim.mean().to("cpu")


 def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
     metric = evaluate.load("wer")
     asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)

@@ -32,4 +33,4 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_s
         predictions=[t["text"].lower() for t in transcriptions],
         references=[t.lower() for t in prompts],
     )

-    return word_error, [t["text"] for t in transcriptions]
\ No newline at end of file
+    return word_error, [t["text"] for t in transcriptions]
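The wer helper transcribes the generated audio with an ASR pipeline and scores the lower-cased transcriptions against the lower-cased prompts via the evaluate library. The scoring step in isolation, with stand-in strings instead of real transcriptions:

    import evaluate

    wer_metric = evaluate.load("wer")
    word_error = wer_metric.compute(
        predictions=["hey how are you"],       # stand-in ASR output
        references=["hey how are you doing"],  # stand-in prompt
    )
    print(word_error)  # fraction of word-level errors, 0.2 for this pair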
training/run_parler_tts_training.py

@@ -33,24 +33,22 @@ from torch.utils.data import DataLoader
 import datasets
 from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets
-from huggingface_hub import Repository, create_repo
+from huggingface_hub import HfApi
 import transformers
-from transformers import (
-    AutoFeatureExtractor,
-    AutoTokenizer,
-    HfArgumentParser
-)
+from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
 from transformers.trainer_pt_utils import LengthGroupedSampler
 from transformers.optimization import get_scheduler
 from transformers.utils import send_example_telemetry
 from accelerate import Accelerator
 from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
 from accelerate.utils.memory import release_memory

 from parler_tts import (
-    ParlerTTSForConditionalGeneration,
     ParlerTTSConfig,
+    ParlerTTSForConditionalGeneration,
     build_delay_pattern_mask,
 )

@@ -301,9 +299,7 @@ def main():
     # update pad token id and decoder_start_token_id
     config.update(
         {
-            "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
+            "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
             "decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else config.decoder_start_token_id,

@@ -574,16 +570,18 @@ def main():
         texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
         prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
         audios = [a.cpu().numpy() for a in audios]
         clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
         results["clap"] = clap_score

-        word_error, transcriptions = wer(model_args.asr_model_name_or_path, prompts, audios, device, training_args.per_device_eval_batch_size, sampling_rate)
+        word_error, transcriptions = wer(
+            model_args.asr_model_name_or_path,
+            prompts,
+            audios,
+            device,
+            training_args.per_device_eval_batch_size,
+            sampling_rate,
+        )
         results["wer"] = word_error

         return results, texts, prompts, audios, transcriptions

@@ -673,14 +671,13 @@ def main():
     if accelerator.is_main_process:
         if training_args.push_to_hub:
-            # Retrieve of infer repo_name
+            api = HfApi(token=training_args.hub_token)
+
+            # Create repo (repo_name from args or inferred)
             repo_name = training_args.hub_model_id
             if repo_name is None:
                 repo_name = Path(training_args.output_dir).absolute().name
-            # Create repo and retrieve repo_id
-            repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
-            # Clone repo locally
-            repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
+            repo_id = api.create_repo(repo_name, exist_ok=True).repo_id

             with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
                 if "wandb" not in gitignore:

@@ -874,7 +871,9 @@ def main():
                     accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
                     accelerator.wait_for_everyone()
                     if accelerator.is_main_process:
-                        rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger)
+                        rotate_checkpoints(
+                            training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
+                        )

                         if cur_step == total_train_steps:
                             # un-wrap student model for save

@@ -882,9 +881,11 @@ def main():
                             unwrapped_model.save_pretrained(training_args.output_dir)

                             if training_args.push_to_hub:
-                                repo.push_to_hub(
+                                api.upload_folder(
+                                    repo_id=repo_id,
+                                    folder_path=training_args.output_dir,
                                     commit_message=f"Saving train state of step {cur_step}",
-                                    blocking=False,
+                                    run_as_future=True,
                                 )

                 if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):

@@ -1014,4 +1015,4 @@ def main():
 if __name__ == "__main__":
     set_start_method("spawn")
-    main()
\ No newline at end of file
+    main()
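The substantive change in this file is the move from the deprecated Repository workflow to plain HfApi calls: the repo is created once with api.create_repo(..., exist_ok=True) and each checkpoint is pushed with api.upload_folder(..., run_as_future=True) so the upload does not block the training loop. A minimal sketch of that flow (repo name and folder path are placeholders):

    from huggingface_hub import HfApi

    api = HfApi(token=None)  # None falls back to the locally cached token

    # Create (or reuse) the target repo and keep its fully qualified id.
    repo_id = api.create_repo("my-parler-tts-run", exist_ok=True).repo_id  # placeholder name

    # Non-blocking upload: returns a Future so training can continue while files transfer.
    future = api.upload_folder(
        repo_id=repo_id,
        folder_path="./output_dir",  # placeholder path
        commit_message="Saving train state of step 1000",
        run_as_future=True,
    )
    future.result()  # optionally block at the end of training to ensure the upload finished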
training/utils.py

@@ -8,6 +8,7 @@ from typing import Dict, List
 import torch
 from wandb import Audio


 def list_field(default=None, metadata=None):
     return field(default_factory=lambda: default, metadata=metadata)

@@ -121,4 +122,4 @@ def log_pred(
             ]
         },
         step=step,
-    )
\ No newline at end of file
+    )
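list_field is a thin wrapper around dataclasses.field: a mutable default such as a list cannot be passed as default=, so it has to go through default_factory. A self-contained sketch of why the helper exists:

    from dataclasses import dataclass, field

    def list_field(default=None, metadata=None):
        # Same idea as the helper above: wrap the (possibly mutable) default in a factory.
        return field(default_factory=lambda: default, metadata=metadata)

    @dataclass
    class Example:  # hypothetical, for illustration only
        tags: list = list_field(default=["train"], metadata={"help": "dataset tags"})

    print(Example().tags)  # ['train']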