Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
fc66e60b
Commit
fc66e60b
authored
Feb 21, 2024
by
Yoach Lacombe
Browse files
fix some bugs
parent
b6341055
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
40 additions
and
38 deletions
+40
-38
run_stable_speech_training.py
run_stable_speech_training.py
+28
-26
stable_speech/configuration_stable_speech.py
stable_speech/configuration_stable_speech.py
+4
-4
stable_speech/modeling_stable_speech.py
stable_speech/modeling_stable_speech.py
+8
-8
No files found.
run_stable_speech_training.py
View file @
fc66e60b
...
@@ -30,7 +30,6 @@ from dataclasses import dataclass, field
...
@@ -30,7 +30,6 @@ from dataclasses import dataclass, field
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
import
datasets
import
datasets
import
evaluate
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
from
datasets
import
DatasetDict
,
load_dataset
,
Dataset
,
IterableDataset
,
interleave_datasets
,
concatenate_datasets
from
datasets
import
DatasetDict
,
load_dataset
,
Dataset
,
IterableDataset
,
interleave_datasets
,
concatenate_datasets
...
@@ -315,7 +314,7 @@ class DataSeq2SeqTrainingArguments:
...
@@ -315,7 +314,7 @@ class DataSeq2SeqTrainingArguments:
@
dataclass
@
dataclass
class
DataCollator
MusicGen
WithPadding
:
class
DataCollator
StableSpeech
WithPadding
:
"""
"""
Data collator that will dynamically pad the inputs received.
Data collator that will dynamically pad the inputs received.
Args:
Args:
...
@@ -360,16 +359,14 @@ class DataCollatorMusicGenWithPadding:
...
@@ -360,16 +359,14 @@ class DataCollatorMusicGenWithPadding:
input_ids
=
[{
"input_ids"
:
feature
[
"input_ids"
]}
for
feature
in
features
]
input_ids
=
[{
"input_ids"
:
feature
[
"input_ids"
]}
for
feature
in
features
]
input_ids
=
self
.
description_tokenizer
.
pad
(
input_ids
,
return_tensors
=
"pt"
,
padding
=
self
.
padding
,
pad_to_multiple_of
=
self
.
pad_to_multiple_of
)
input_ids
=
self
.
description_tokenizer
.
pad
(
input_ids
,
return_tensors
=
"pt"
,
padding
=
self
.
padding
,
pad_to_multiple_of
=
self
.
pad_to_multiple_of
)
batch
=
{
"labels"
:
labels
,
**
input_ids
}
prompt_input_ids
=
[{
"input_ids"
:
feature
[
"prompt_input_ids"
]}
for
feature
in
features
]
prompt_input_ids
=
[{
"input_ids"
:
feature
[
"prompt_input_ids"
]}
for
feature
in
features
]
prompt_input_ids
=
self
.
prompt_tokenizer
.
pad
(
prompt_input_ids
,
return_tensors
=
"pt"
,
padding
=
self
.
padding
,
pad_to_multiple_of
=
self
.
pad_to_multiple_of
)
prompt_input_ids
=
self
.
prompt_tokenizer
.
pad
(
prompt_input_ids
,
return_tensors
=
"pt"
,
padding
=
self
.
padding
,
pad_to_multiple_of
=
self
.
pad_to_multiple_of
)
batch
[
"prompt_input_ids"
]
=
prompt_input_ids
[
"input_ids"
]
batch
[
"prompt_input_ids"
]
=
prompt_input_ids
[
"input_ids"
]
if
"attention_mask"
in
prompt_input_ids
:
if
"attention_mask"
in
prompt_input_ids
:
batch
[
"prompt_attention_mask"
]
=
prompt_input_ids
[
"attention_mask"
]
batch
[
"prompt_attention_mask"
]
=
prompt_input_ids
[
"attention_mask"
]
batch
=
{
"labels"
:
labels
,
**
input_ids
}
if
self
.
feature_extractor_input_name
in
features
[
0
]:
if
self
.
feature_extractor_input_name
in
features
[
0
]:
# TODO: verify that it works
# TODO: verify that it works
...
@@ -485,29 +482,30 @@ def load_multiple_datasets(
...
@@ -485,29 +482,30 @@ def load_multiple_datasets(
**
kwargs
,
**
kwargs
,
)
)
if
id_column_name
is
not
None
and
id_column_name
not
in
dataset
:
if
id_column_name
is
not
None
and
id_column_name
not
in
dataset
.
column_names
:
raise
ValueError
(
raise
ValueError
(
f
"id_column_name=
{
id_column_name
}
but has not been found in the dataset columns"
f
"id_column_name=
{
id_column_name
}
but has not been found in the dataset columns"
f
"- one of
{
', '
.
join
(
list
(
dataset
.
columns
))
}
."
f
"- one of
{
', '
.
join
(
list
(
dataset
.
column
_name
s
))
}
."
)
)
if
id_column_name
is
not
None
and
id_column_name
not
in
metadata_dataset
:
if
id_column_name
is
not
None
and
id_column_name
not
in
metadata_dataset
.
column_names
:
raise
ValueError
(
raise
ValueError
(
f
"id_column_name=
{
id_column_name
}
but has not been found in the metadata dataset columns"
f
"id_column_name=
{
id_column_name
}
but has not been found in the metadata dataset columns"
f
"- one of
{
', '
.
join
(
list
(
metadata_dataset
.
columns
))
}
."
f
"- one of
{
', '
.
join
(
list
(
metadata_dataset
.
column
_name
s
))
}
."
)
)
elif
id_column_name
is
not
None
:
elif
id_column_name
is
not
None
:
metadata_dataset
=
metadata_dataset
.
rename_column
(
id_column_name
,
f
"metadata_
{
id_column_name
}
"
)
metadata_dataset
=
metadata_dataset
.
rename_column
(
id_column_name
,
f
"metadata_
{
id_column_name
}
"
)
metadata_columns_to_
keep
=
set
(
metadata_dataset
.
columns
).
intersection
(
set
(
dataset
.
column_names
))
metadata_columns_to_
remove
=
set
(
metadata_dataset
.
column
_name
s
).
intersection
(
set
(
dataset
.
column_names
))
metadata_dataset
=
metadata_dataset
.
remove_columns
(
set
(
metadata_dataset
.
columns
)
-
metadata_columns_to_
keep
)
metadata_dataset
=
metadata_dataset
.
remove_columns
(
metadata_columns_to_
remove
)
dataset
=
concatenate_datasets
([
dataset
,
metadata_dataset
],
axis
=
1
)
dataset
=
concatenate_datasets
([
dataset
,
metadata_dataset
],
axis
=
1
)
if
id_column_name
is
not
None
:
if
id_column_name
is
not
None
:
if
len
(
dataset
.
filter
(
lambda
id1
,
id2
:
id1
!=
id2
,
input_columns
=
[
id_column_name
,
f
"metadata_
{
id_column_name
}
"
]))
!=
0
:
if
len
(
dataset
.
filter
(
lambda
id1
,
id2
:
id1
!=
id2
,
input_columns
=
[
id_column_name
,
f
"metadata_
{
id_column_name
}
"
]))
!=
0
:
raise
ValueError
(
f
"Concatenate didn't work. Some ids don't correspond on dataset
{
dataset_dict
[
"
name
"
]
}
"
)
raise
ValueError
(
f
"Concatenate didn't work. Some ids don't correspond on dataset
{
dataset_dict
[
'
name
'
]
}
"
)
dataset_features
=
dataset
.
features
.
keys
()
if
columns_to_keep
is
not
None
:
if
columns_to_keep
is
not
None
:
dataset
=
dataset
.
remove_columns
(
set
(
dataset_features
-
columns_to_keep
))
dataset
=
dataset
.
remove_columns
(
set
(
dataset_features
-
columns_to_keep
))
all_datasets
.
append
(
dataset
)
all_datasets
.
append
(
dataset
)
...
@@ -586,9 +584,14 @@ def main():
...
@@ -586,9 +584,14 @@ def main():
# 1. First, let's load the dataset
# 1. First, let's load the dataset
raw_datasets
=
DatasetDict
()
raw_datasets
=
DatasetDict
()
num_workers
=
data_args
.
preprocessing_num_workers
num_workers
=
data_args
.
preprocessing_num_workers
if
training_args
.
do_train
:
columns_to_keep
=
[
data_args
.
target_audio_column_name
,
data_args
.
prompt_column_name
]
if
data_args
.
description_column_name
is
not
None
:
columns_to_keep
.
append
(
data_args
.
description_column_name
)
if
data_args
.
conditional_audio_column_name
is
not
None
:
columns_to_keep
.
append
(
data_args
.
conditional_audio_column_name
)
if
training_args
.
do_train
:
raw_datasets
[
"train"
]
=
load_multiple_datasets
(
raw_datasets
[
"train"
]
=
load_multiple_datasets
(
data_args
.
train_dataset_name
,
data_args
.
train_dataset_name
,
data_args
.
train_dataset_config_name
,
data_args
.
train_dataset_config_name
,
...
@@ -597,10 +600,9 @@ def main():
...
@@ -597,10 +600,9 @@ def main():
dataset_samples
=
data_args
.
train_dataset_samples
,
dataset_samples
=
data_args
.
train_dataset_samples
,
seed
=
training_args
.
seed
,
seed
=
training_args
.
seed
,
cache_dir
=
model_args
.
cache_dir
,
cache_dir
=
model_args
.
cache_dir
,
token
=
True
if
model_args
.
token
else
None
,
trust_remote_code
=
model_args
.
trust_remote_code
,
num_proc
=
data_args
.
preprocessing_num_workers
,
num_proc
=
data_args
.
preprocessing_num_workers
,
id_column_name
=
data_args
.
id_column_name
,
id_column_name
=
data_args
.
id_column_name
,
columns_to_keep
=
columns_to_keep
,
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
)
)
...
@@ -645,10 +647,9 @@ def main():
...
@@ -645,10 +647,9 @@ def main():
data_args
.
eval_metadata_dataset_name
,
data_args
.
eval_metadata_dataset_name
,
splits
=
data_args
.
eval_split_name
,
splits
=
data_args
.
eval_split_name
,
cache_dir
=
model_args
.
cache_dir
,
cache_dir
=
model_args
.
cache_dir
,
token
=
True
if
model_args
.
token
else
None
,
trust_remote_code
=
model_args
.
trust_remote_code
,
num_proc
=
data_args
.
preprocessing_num_workers
,
num_proc
=
data_args
.
preprocessing_num_workers
,
id_column_name
=
data_args
.
id_column_name
,
id_column_name
=
data_args
.
id_column_name
,
columns_to_keep
=
columns_to_keep
,
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
)
)
...
@@ -752,11 +753,11 @@ def main():
...
@@ -752,11 +753,11 @@ def main():
if
description_column_name
is
not
None
:
if
description_column_name
is
not
None
:
text
=
batch
[
description_column_name
]
text
=
batch
[
description_column_name
]
batch
[
"input_ids"
]
=
description_tokenizer
(
text
)[
"input_ids"
]
batch
[
"input_ids"
]
=
description_tokenizer
(
text
.
strip
()
)[
"input_ids"
]
if
prompt_column_name
is
not
None
:
if
prompt_column_name
is
not
None
:
text
=
batch
[
prompt_column_name
]
text
=
batch
[
prompt_column_name
]
batch
[
"prompt_input_ids"
]
=
prompt_tokenizer
(
text
)[
"input_ids"
]
batch
[
"prompt_input_ids"
]
=
prompt_tokenizer
(
text
.
strip
()
)[
"input_ids"
]
# load audio
# load audio
target_sample
=
batch
[
target_audio_column_name
]
target_sample
=
batch
[
target_audio_column_name
]
...
@@ -878,8 +879,8 @@ def main():
...
@@ -878,8 +879,8 @@ def main():
config
.
save_pretrained
(
training_args
.
output_dir
)
config
.
save_pretrained
(
training_args
.
output_dir
)
# Instantiate custom data collator
# Instantiate custom data collator
data_collator
=
DataCollator
MusicGen
WithPadding
(
data_collator
=
DataCollator
StableSpeech
WithPadding
(
feature_extractor
=
feature_extractor
,
feature_extractor_input_name
=
feature_extractor_input_name
,
prompt_tokenizer
=
prompt_tokenizer
,
description_tokenizer
=
description_tokenizer
audio_
feature_extractor
=
feature_extractor
,
feature_extractor_input_name
=
feature_extractor_input_name
,
prompt_tokenizer
=
prompt_tokenizer
,
description_tokenizer
=
description_tokenizer
)
)
# Freeze Encoders
# Freeze Encoders
...
@@ -956,8 +957,9 @@ def main():
...
@@ -956,8 +957,9 @@ def main():
# use last checkpoint if exist
# use last checkpoint if exist
if
last_checkpoint
is
not
None
:
if
last_checkpoint
is
not
None
:
checkpoint
=
last_checkpoint
checkpoint
=
last_checkpoint
elif
os
.
path
.
isdir
(
model_args
.
model_name_or_path
):
# TODO: it's loading trainer from model_name_or_path doesn't work if saving config
checkpoint
=
model_args
.
model_name_or_path
# elif os.path.isdir(model_args.model_name_or_path):
# checkpoint = model_args.model_name_or_path
else
:
else
:
checkpoint
=
None
checkpoint
=
None
...
...
stable_speech/configuration_stable_speech.py
View file @
fc66e60b
...
@@ -137,8 +137,8 @@ class StableSpeechConfig(PretrainedConfig):
...
@@ -137,8 +137,8 @@ class StableSpeechConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
documentation from [`PretrainedConfig`] for more information.
Args:
Args:
prompt_embed_dim
(`int`, *optional*, defaults to 1024):
vocab_size
(`int`, *optional*, defaults to 1024):
Dimensionality of the prompt embedding layer
.
Vocabulary size of the prompt # TODO
.
kwargs (*optional*):
kwargs (*optional*):
Dictionary of keyword arguments. Notably:
Dictionary of keyword arguments. Notably:
...
@@ -189,7 +189,7 @@ class StableSpeechConfig(PretrainedConfig):
...
@@ -189,7 +189,7 @@ class StableSpeechConfig(PretrainedConfig):
model_type
=
"stable_speech"
model_type
=
"stable_speech"
is_composition
=
True
is_composition
=
True
def
__init__
(
self
,
prompt_embed_dim
=
1024
,
**
kwargs
):
def
__init__
(
self
,
vocab_size
=
1024
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
if
"text_encoder"
not
in
kwargs
or
"audio_encoder"
not
in
kwargs
or
"decoder"
not
in
kwargs
:
if
"text_encoder"
not
in
kwargs
or
"audio_encoder"
not
in
kwargs
or
"decoder"
not
in
kwargs
:
raise
ValueError
(
"Config has to be initialized with text_encoder, audio_encoder and decoder config"
)
raise
ValueError
(
"Config has to be initialized with text_encoder, audio_encoder and decoder config"
)
...
@@ -202,7 +202,7 @@ class StableSpeechConfig(PretrainedConfig):
...
@@ -202,7 +202,7 @@ class StableSpeechConfig(PretrainedConfig):
decoder_config
=
kwargs
.
pop
(
"decoder"
)
decoder_config
=
kwargs
.
pop
(
"decoder"
)
self
.
prompt_embed_dim
=
prompt_embed_dim
self
.
vocab_size
=
vocab_size
self
.
text_encoder
=
AutoConfig
.
for_model
(
text_encoder_model_type
,
**
text_encoder_config
)
self
.
text_encoder
=
AutoConfig
.
for_model
(
text_encoder_model_type
,
**
text_encoder_config
)
self
.
audio_encoder
=
AutoConfig
.
for_model
(
audio_encoder_model_type
,
**
audio_encoder_config
)
self
.
audio_encoder
=
AutoConfig
.
for_model
(
audio_encoder_model_type
,
**
audio_encoder_config
)
self
.
decoder
=
StableSpeechDecoderConfig
(
**
decoder_config
)
self
.
decoder
=
StableSpeechDecoderConfig
(
**
decoder_config
)
...
...
stable_speech/modeling_stable_speech.py
View file @
fc66e60b
...
@@ -730,7 +730,6 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
...
@@ -730,7 +730,6 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
# if prompt_hidden_states, fuse to inputs_embeds and update input shape
# if prompt_hidden_states, fuse to inputs_embeds and update input shape
if
prompt_hidden_states
is
not
None
:
if
prompt_hidden_states
is
not
None
:
inputs_embeds
=
torch
.
cat
([
prompt_hidden_states
,
inputs_embeds
],
dim
=
1
)
inputs_embeds
=
torch
.
cat
([
prompt_hidden_states
,
inputs_embeds
],
dim
=
1
)
input_shape
=
inputs_embeds
.
size
()[:
-
1
]
# TODO: verify if prompt attention mask is required
# TODO: verify if prompt attention mask is required
# As it is, the masked ids from the prompt will still count in the positions embeddings
# As it is, the masked ids from the prompt will still count in the positions embeddings
...
@@ -740,9 +739,9 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
...
@@ -740,9 +739,9 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
logger
.
warning_once
(
logger
.
warning_once
(
"`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
"`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
)
)
attention_mask
=
torch
.
cat
([
prompt_attention_mask
,
torch
.
ones
(
input_shape
,
device
=
self
.
device
,
dtype
=
prompt_attention_mask
.
dtype
)])
attention_mask
=
torch
.
cat
([
prompt_attention_mask
,
torch
.
ones
(
input_shape
,
device
=
self
.
device
,
dtype
=
prompt_attention_mask
.
dtype
)]
,
dim
=
1
)
input_shape
=
inputs_embeds
.
size
()[:
-
1
]
attention_mask
=
_prepare_4d_causal_attention_mask
(
attention_mask
=
_prepare_4d_causal_attention_mask
(
attention_mask
,
input_shape
,
inputs_embeds
,
past_key_values_length
attention_mask
,
input_shape
,
inputs_embeds
,
past_key_values_length
)
)
...
@@ -1538,7 +1537,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -1538,7 +1537,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
self
.
enc_to_dec_proj
=
nn
.
Linear
(
self
.
text_encoder
.
config
.
hidden_size
,
self
.
decoder
.
config
.
hidden_size
)
self
.
enc_to_dec_proj
=
nn
.
Linear
(
self
.
text_encoder
.
config
.
hidden_size
,
self
.
decoder
.
config
.
hidden_size
)
# prompt embeddings
# prompt embeddings
self
.
embed_prompts
=
nn
.
Embedding
(
config
.
prompt_embed_dim
,
self
.
decoder
.
config
.
hidden_size
)
self
.
embed_prompts
=
nn
.
Embedding
(
config
.
vocab_size
,
self
.
decoder
.
config
.
hidden_size
)
if
self
.
text_encoder
.
get_output_embeddings
()
is
not
None
:
if
self
.
text_encoder
.
get_output_embeddings
()
is
not
None
:
...
@@ -1557,7 +1556,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -1557,7 +1556,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
self
.
post_init
()
self
.
post_init
()
def
_init_weights
(
self
,
module
):
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_factor
std
=
self
.
decoder
.
config
.
initializer_factor
if
isinstance
(
module
,
(
nn
.
Linear
,
nn
.
Conv1d
)):
if
isinstance
(
module
,
(
nn
.
Linear
,
nn
.
Conv1d
)):
module
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
std
)
module
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
std
)
if
module
.
bias
is
not
None
:
if
module
.
bias
is
not
None
:
...
@@ -1787,7 +1786,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -1787,7 +1786,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
)
)
if
"config"
not
in
kwargs_decoder
:
if
"config"
not
in
kwargs_decoder
:
decoder_config
,
kwargs_decoder
=
AutoConfig
.
from_pretrained
(
# TODO: reput AutoConfig once added to transformers
decoder_config
,
kwargs_decoder
=
StableSpeechDecoderConfig
.
from_pretrained
(
decoder_pretrained_model_name_or_path
,
**
kwargs_decoder
,
return_unused_kwargs
=
True
decoder_pretrained_model_name_or_path
,
**
kwargs_decoder
,
return_unused_kwargs
=
True
)
)
...
@@ -1923,7 +1923,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -1923,7 +1923,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if
(
labels
is
not
None
)
and
(
decoder_input_ids
is
None
and
decoder_inputs_embeds
is
None
):
if
(
labels
is
not
None
)
and
(
decoder_input_ids
is
None
and
decoder_inputs_embeds
is
None
):
decoder_input_ids
=
shift_tokens_right
(
decoder_input_ids
=
shift_tokens_right
(
labels
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
labels
.
transpose
(
1
,
2
)
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
)
)
elif
decoder_input_ids
is
None
and
decoder_inputs_embeds
is
None
:
elif
decoder_input_ids
is
None
and
decoder_inputs_embeds
is
None
:
...
@@ -2190,7 +2190,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -2190,7 +2190,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
return
model_kwargs
return
model_kwargs
def
prepare_decoder_input_ids_from_labels
(
self
,
labels
:
torch
.
Tensor
):
def
prepare_decoder_input_ids_from_labels
(
self
,
labels
:
torch
.
Tensor
):
return
shift_tokens_right
(
labels
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
)
return
shift_tokens_right
(
labels
.
transpose
(
1
,
2
)
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
)
def
resize_token_embeddings
(
self
,
*
args
,
**
kwargs
):
def
resize_token_embeddings
(
self
,
*
args
,
**
kwargs
):
# TODO: now it's possible with prompt_embeddings
# TODO: now it's possible with prompt_embeddings
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment