chenpangpang / parler-tts / Commits / bdb03638

Unverified commit bdb03638, authored May 14, 2024 by Yoach Lacombe, committed by GitHub on May 14, 2024.

Merge pull request #48 from ylacombe/pr/Wauplin/18

Pr/wauplin/18

Parents: b2b749d1, 3f5fd26c
Showing 7 changed files with 48 additions and 37 deletions (+48 -37).
parler_tts/configuration_parler_tts.py (+1 -1)
parler_tts/modeling_parler_tts.py (+6 -3)
training/arguments.py (+7 -3)
training/data.py (+2 -1)
training/eval.py (+3 -2)
training/run_parler_tts_training.py (+27 -26)
training/utils.py (+2 -1)
parler_tts/configuration_parler_tts.py

@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
             Dimensionality of the layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 24):
 ...
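For reference, these documented defaults are what an argument-free instantiation yields; a minimal sketch (the import path is assumed from the package layout):

    from parler_tts import ParlerTTSDecoderConfig  # import path assumed

    cfg = ParlerTTSDecoderConfig()
    # Defaults documented above: 2049-entry codebook vocab, 1024-dim hidden states, 24 layers
    print(cfg.vocab_size, cfg.hidden_size, cfg.num_hidden_layers)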
parler_tts/modeling_parler_tts.py

@@ -1522,7 +1522,7 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             output_ids = outputs.sequences
         else:
             output_ids = outputs

         # apply the pattern mask to the final ids
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
 ...
@@ -2460,7 +2460,10 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if "encoder_outputs" not in model_kwargs:
             # encoder_outputs are created and added to `model_kwargs`
             model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
                 inputs_tensor, model_kwargs, model_input_name, generation_config,
             )
         if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
 ...
@@ -2667,4 +2670,4 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         outputs.sequences = output_values
         return outputs
     else:
         return output_values
\ No newline at end of file
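The `apply_delay_pattern_mask` call above undoes the MusicGen-style delay pattern, in which codebook k is shifted right by k steps so that all codebooks can be generated in a single autoregressive stream. A toy sketch of building such a pattern (not the library's implementation; the function name and pad id are invented):

    import torch

    def build_toy_delay_pattern(codes: torch.Tensor, pad_id: int) -> torch.Tensor:
        # codes: (num_codebooks, seq_len); shift row k right by k steps, pad the gaps
        num_codebooks, seq_len = codes.shape
        out = torch.full((num_codebooks, seq_len + num_codebooks - 1), pad_id, dtype=codes.dtype)
        for k in range(num_codebooks):
            out[k, k : k + seq_len] = codes[k]
        return out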
training/arguments.py

@@ -3,6 +3,7 @@ from typing import Optional
 from transformers import Seq2SeqTrainingArguments

 @dataclass
 class ModelArguments:
     """
 ...
@@ -67,15 +68,18 @@ class ModelArguments:
     )
     asr_model_name_or_path: str = field(
         default="distil-whisper/distil-large-v2",
-        metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={"help": "Used to compute WER during evaluation. Path to pretrained model or model identifier from huggingface.co/models"},
     )
     clap_model_name_or_path: str = field(
         default="laion/larger_clap_music_and_speech",
-        metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"}
+        metadata={"help": "Used to compute audio similarity during evaluation. Path to pretrained model or model identifier from huggingface.co/models"},
     )

 @dataclass
 class DataTrainingArguments:
     """
 ...
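These fields follow the standard `HfArgumentParser` dataclass pattern (the parser is imported in training/run_parler_tts_training.py below); a self-contained sketch with a single toy field:

    from dataclasses import dataclass, field
    from transformers import HfArgumentParser

    @dataclass
    class ToyModelArguments:
        asr_model_name_or_path: str = field(
            default="distil-whisper/distil-large-v2",
            metadata={"help": "Used to compute WER during evaluation."},
        )

    (toy_args,) = HfArgumentParser(ToyModelArguments).parse_args_into_dataclasses(
        args=["--asr_model_name_or_path", "openai/whisper-small"]
    )
    print(toy_args.asr_model_name_or_path)  # openai/whisper-small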
training/data.py

@@ -11,6 +11,7 @@ from tqdm import tqdm
 from accelerate import Accelerator

 @dataclass
 class DataCollatorEncodecWithPadding:
     """
 ...
@@ -301,4 +302,4 @@ def load_multiple_datasets(
     with accelerator.main_process_first():
         interleaved_dataset = concatenate_datasets(all_datasets)
     return interleaved_dataset
\ No newline at end of file
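Note that `concatenate_datasets` appends the loaded datasets end to end (despite the `interleaved_dataset` variable name); a minimal illustration with toy data:

    from datasets import Dataset, concatenate_datasets

    all_datasets = [
        Dataset.from_dict({"text": ["a", "b"]}),
        Dataset.from_dict({"text": ["c"]}),
    ]
    combined = concatenate_datasets(all_datasets)
    print(len(combined))  # 3 rows, kept in order rather than interleaved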
training/eval.py

 import torch
 import evaluate
 from transformers import AutoModel, AutoProcessor, pipeline
 ...
@@ -20,6 +20,7 @@ def clap_similarity(clap_model_name_or_path, texts, audios, device):
     clap_inputs.to("cpu")
     return cosine_sim.mean().to("cpu")

 def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
     metric = evaluate.load("wer")
     asr_pipeline = pipeline(model=asr_model_name_or_path, device=device)
 ...
@@ -32,4 +33,4 @@ def wer(asr_model_name_or_path, prompts, audios, device, per_device_eval_batch_size, sampling_rate):
         predictions=[t["text"].lower() for t in transcriptions], references=[t.lower() for t in prompts]
     )
     return word_error, [t["text"] for t in transcriptions]
\ No newline at end of file
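A hedged usage sketch of the `wer` helper above (module path, inputs, and the expected audio format are assumptions for illustration):

    import numpy as np
    from training.eval import wer  # import path assumed from the repo layout

    prompts = ["hello world"]
    audios = [np.zeros(16000, dtype=np.float32)]  # one second of silence at 16 kHz

    word_error, transcriptions = wer(
        "distil-whisper/distil-large-v2",  # repo default ASR checkpoint
        prompts,
        audios,
        device="cpu",  # forwarded to transformers.pipeline
        per_device_eval_batch_size=1,
        sampling_rate=16000,
    )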
training/run_parler_tts_training.py

@@ -33,24 +33,22 @@ from torch.utils.data import DataLoader
 import datasets
 from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets
-from huggingface_hub import Repository, create_repo
+from huggingface_hub import HfApi

 import transformers
-from transformers import (
-    AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
-)
+from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
 from transformers.trainer_pt_utils import LengthGroupedSampler
 from transformers.optimization import get_scheduler
 from transformers.utils import send_example_telemetry

 from accelerate import Accelerator
 from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
 from accelerate.utils.memory import release_memory

 from parler_tts import (
-    ParlerTTSForConditionalGeneration,
-    ParlerTTSConfig,
+    ParlerTTSConfig,
+    ParlerTTSForConditionalGeneration,
     build_delay_pattern_mask,
 )
 ...
@@ -301,9 +299,7 @@ def main():
     # update pad token id and decoder_start_token_id
     config.update(
         {
-            "pad_token_id": model_args.pad_token_id
-            if model_args.pad_token_id is not None
-            else config.pad_token_id,
+            "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
             "decoder_start_token_id": model_args.decoder_start_token_id
             if model_args.decoder_start_token_id is not None
             else config.decoder_start_token_id,
 ...
@@ -574,16 +570,18 @@ def main():
         texts = description_tokenizer.batch_decode(input_ids, skip_special_tokens=True)
         prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
         audios = [a.cpu().numpy() for a in audios]
         clap_score = clap_similarity(model_args.clap_model_name_or_path, texts, audios, device)
         results["clap"] = clap_score

-        word_error, transcriptions = wer(model_args.asr_model_name_or_path,
-                                         prompts,
-                                         audios,
-                                         device,
-                                         training_args.per_device_eval_batch_size,
-                                         sampling_rate)
+        word_error, transcriptions = wer(
+            model_args.asr_model_name_or_path,
+            prompts,
+            audios,
+            device,
+            training_args.per_device_eval_batch_size,
+            sampling_rate,
+        )
         results["wer"] = word_error

         return results, texts, prompts, audios, transcriptions
 ...
@@ -673,14 +671,13 @@ def main():
         if accelerator.is_main_process:
             if training_args.push_to_hub:
-                # Retrieve of infer repo_name
+                api = HfApi(token=training_args.hub_token)
+
+                # Create repo (repo_name from args or inferred)
                 repo_name = training_args.hub_model_id
                 if repo_name is None:
                     repo_name = Path(training_args.output_dir).absolute().name
-                # Create repo and retrieve repo_id
-                repo_id = create_repo(repo_name, exist_ok=True, token=training_args.hub_token).repo_id
-                # Clone repo locally
-                repo = Repository(training_args.output_dir, clone_from=repo_id, token=training_args.hub_token)
+                repo_id = api.create_repo(repo_name, exist_ok=True).repo_id

             with open(os.path.join(training_args.output_dir, ".gitignore"), "w+") as gitignore:
                 if "wandb" not in gitignore:
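This hunk is the core of the migration: `HfApi` replaces the deprecated git-based `Repository` workflow, so no local clone is needed. A standalone sketch of the new flow (repo name and token are invented):

    from huggingface_hub import HfApi

    api = HfApi(token="hf_xxx")  # hypothetical token
    repo_id = api.create_repo("my-parler-tts-run", exist_ok=True).repo_id
    # Uploads now happen per call (see the @@ -882 hunk below) instead of via git push:
    api.upload_folder(repo_id=repo_id, folder_path="./output", commit_message="upload checkpoint")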
 ...
@@ -874,7 +871,9 @@ def main():
     accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         rotate_checkpoints(training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger)

         if cur_step == total_train_steps:
             # un-wrap student model for save
 ...
@@ -882,9 +881,11 @@ def main():
             unwrapped_model.save_pretrained(training_args.output_dir)

             if training_args.push_to_hub:
-                repo.push_to_hub(
+                api.upload_folder(
+                    repo_id=repo_id,
+                    folder_path=training_args.output_dir,
                     commit_message=f"Saving train state of step {cur_step}",
-                    blocking=False,
+                    run_as_future=True,
                 )

     if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
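`run_as_future=True` is the `HfApi` counterpart of the old `blocking=False`: the upload is scheduled on a background thread and `upload_folder` returns a `concurrent.futures.Future` immediately. A hedged sketch:

    future = api.upload_folder(
        repo_id=repo_id,
        folder_path="./output",  # hypothetical output dir
        commit_message="Saving train state of step 1000",
        run_as_future=True,
    )
    # training continues; block only when the upload must have completed
    future.result()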
 ...
@@ -1014,4 +1015,4 @@ def main():
 if __name__ == "__main__":
     set_start_method("spawn")
     main()
\ No newline at end of file
training/utils.py

@@ -8,6 +8,7 @@ from typing import Dict, List
 import torch
 from wandb import Audio

 def list_field(default=None, metadata=None):
     return field(default_factory=lambda: default, metadata=metadata)
 ...
@@ -121,4 +122,4 @@ def log_pred(
             ]
         },
         step=step,
     )
\ No newline at end of file