chenpangpang/parler-tts · Commit 11fcc066

working training + generation

Authored Feb 23, 2024 by Yoach Lacombe
Parent: 997bf5e6

Showing 5 changed files with 88 additions and 54 deletions (+88 / -54):

  example_configs/librispeech_tts_r.json         +4  -4
  init_dummy_model.py                            +2  -1
  run_stable_speech_training.py                  +28 -16
  stable_speech/configuration_stable_speech.py   +4  -4
  stable_speech/modeling_stable_speech.py        +50 -29
example_configs/librispeech_tts_r.json

@@ -36,14 +36,14 @@
     "preprocessing_num_workers": 1,
-    "pad_token_id": 2048,
+    "pad_token_id": 2049,
     "decoder_start_token_id": 2048,
     "do_train": true,
-    "num_train_epochs": 20,
+    "num_train_epochs": 1,
     "gradient_accumulation_steps": 1,
     "gradient_checkpointing": true,
-    "per_device_train_batch_size": 16,
+    "per_device_train_batch_size": 8,
     "learning_rate": 1e-6,
     "adam_beta1": 0.9,
     "adam_beta2": 0.95,

@@ -56,7 +56,7 @@
     "predict_with_generate": true,
     "include_inputs_for_metrics": true,
     "evaluation_strategy": "epoch",
-    "per_device_eval_batch_size": 16,
+    "per_device_eval_batch_size": 8,
     "generation_max_length": 400,
     "fp16": true,
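For context on the batch-size change: with gradient_accumulation_steps left at 1, the effective batch size per optimizer step is simply per_device_train_batch_size times the number of devices. A small worked example, assuming single-node data parallelism (the device count is hypothetical and depends on the actual launch):

    # effective batch size per optimizer step (illustrative; device count is not in the config)
    per_device_train_batch_size = 8
    gradient_accumulation_steps = 1
    num_devices = 1
    effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices  # 8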
init_dummy_model.py

@@ -35,7 +35,8 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
 # set the appropriate bos/pad token ids
 model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2048
+model.generation_config.pad_token_id = 2049
+model.generation_config.eos_token_id = 2049

 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
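For a concrete sense of the max_length default: assuming, purely for illustration, an audio encoder frame rate of 75 Hz (the script reads the real value from model.audio_encoder.config.frame_rate):

    # hypothetical frame rate; the real value comes from the audio encoder config
    frame_rate = 75
    max_length = int(30 * frame_rate)  # 2250 decoder steps, i.e. roughly 30 seconds of audio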
run_stable_speech_training.py

@@ -77,16 +77,10 @@ def list_field(default=None, metadata=None):
 class StableSpeechTrainer(Seq2SeqTrainer):
     def _pad_tensors_to_max_len(self, tensor, max_length):
-        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
-            # If PAD token is not defined at least EOS token has to be defined
-            pad_token_id = (
-                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
-            )
+        if self.model.config.pad_token_id is not None:
+            pad_token_id = self.model.config.pad_token_id
         else:
-            if self.model.config.pad_token_id is not None:
-                pad_token_id = self.model.config.pad_token_id
-            else:
-                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
+            raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")

         padded_tensor = pad_token_id * torch.ones(
             (tensor.shape[0], max_length, tensor.shape[2]), dtype=tensor.dtype, device=tensor.device
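A minimal standalone sketch of the simplified padding logic above, assuming a (bsz, seq_len, num_codebooks) label tensor with toy values:

    import torch

    pad_token_id = 2049  # pad/eos id introduced in this commit
    max_length = 8
    labels = torch.randint(0, 2048, (2, 5, 9))  # (bsz, seq_len, num_codebooks), toy values

    padded_tensor = pad_token_id * torch.ones(
        (labels.shape[0], max_length, labels.shape[2]), dtype=labels.dtype, device=labels.device
    )
    padded_tensor[:, : labels.shape[1]] = labels  # copy the real tokens, leave the tail padded
    print(padded_tensor.shape)  # torch.Size([2, 8, 9])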
@@ -387,6 +381,7 @@ class DataCollatorStableSpeechWithPadding:
         prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
         prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
+        # TODO: check it's been padded on the left
         batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
         if "attention_mask" in prompt_input_ids:
             batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
@@ -676,6 +671,7 @@ def main():
     )
     # update pad token id and decoder_start_token_id
+    # TODO: verify if this makes sense, maybe should do it for model.decoder
     config.update({
         "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else model.config.pad_token_id,
         "decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else model.config.decoder_start_token_id,
@@ -700,6 +696,7 @@ def main():
         token=data_args.token,
         trust_remote_code=data_args.trust_remote_code,
         use_fast=model_args.use_fast_tokenizer,
+        padding_side="left",  # prompt has to be padded on the left bc it's prepended to codebooks hidden states
     )

     # load description tokenizer
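Why padding_side="left" matters: the prompt embeddings are prepended to the codebook hidden states, so padding should sit at the far left and the real prompt tokens should end up directly next to the audio tokens. A hedged sketch with a generic Hugging Face tokenizer (the checkpoint name is only an example; the script loads whatever model_args specifies):

    from transformers import AutoTokenizer

    prompt_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base", padding_side="left")
    batch = prompt_tokenizer(["Hey, how are you doing today?", "Hi"], padding=True, return_tensors="pt")

    # with left padding, the pad tokens and attention-mask zeros come first,
    # so the last positions of every row hold real prompt tokens
    print(batch["attention_mask"])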
@@ -740,6 +737,10 @@ def main():
     description_column_name = data_args.description_column_name
     prompt_column_name = data_args.prompt_column_name
     feature_extractor_input_name = feature_extractor.model_input_names[0]
+    audio_encoder_eos_token_id = config.decoder.pad_token_id
+    audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
+    max_length = model.generation_config.max_length
+    num_codebooks = model.decoder.config.num_codebooks

     # resample target audio
     raw_datasets = raw_datasets.cast_column(
@@ -794,7 +795,6 @@ def main():
     # no need to prepare audio_decoder because used for inference without mixed precision
     # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
-    # TODO: load another model
     audio_decoder = model.audio_encoder

     encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, feature_extractor_input_name)
@@ -832,18 +832,28 @@ def main():
             all_ratios.extend(generate_labels["ratio"].cpu())
             all_lens.extend(generate_labels["len_audio"].cpu())

+        # (1, codebooks, seq_len) where seq_len=1
+        eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id
+
         def postprocess_dataset(sample, idx):
-            # (1, seq_len, codebooks, bsz)
+            # (1, codebooks, seq_len)
             labels = all_generated_labels[idx].transpose(0, 1).unsqueeze(0)
+            len_ = int(all_ratios[idx] * all_lens[idx])
+            labels = labels[:, :, :len_]
+
+            # add eos token column
+            labels = torch.cat([labels, eos_labels.to(labels.device).to(labels.dtype)], dim=-1)
+
             labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(
                 labels,
-                model.generation_config.decoder_start_token_id,
-                model.generation_config.max_length + model.decoder.config.num_codebooks
+                audio_encoder_bos_token_id,
+                audio_encoder_eos_token_id,
+                max_length + num_codebooks
             )
             labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
-            len_ = int(all_ratios[idx] * all_lens[idx])
             # the first timestamp is associated to a row full of BOS, let's get rid of it
-            sample["labels"] = labels[:, 1:len_]
+            sample["labels"] = labels[:, 1:]
             return sample

         # TODO: done multiple times, how to deal with it.
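As a rough standalone illustration of the label post-processing above — truncating to the true audio length and appending one EOS column per codebook — with invented shapes and values (the delay-pattern step itself is sketched under the modeling file below):

    import torch

    num_codebooks, seq_len = 4, 10
    eos_id, ratio, len_audio = 2049, 0.8, 10

    # fake encoded labels, shaped (1, codebooks, seq_len) as in the new comment
    labels = torch.randint(0, 2048, (1, num_codebooks, seq_len))

    # truncate to the real audio length before appending EOS
    len_ = int(ratio * len_audio)                     # 8
    labels = labels[:, :, :len_]

    # one EOS column shared by all codebooks, shaped (1, codebooks, 1)
    eos_labels = torch.ones((1, num_codebooks, 1), dtype=labels.dtype) * eos_id
    labels = torch.cat([labels, eos_labels], dim=-1)  # (1, 4, 9)
    print(labels.shape, labels[0, :, -1])             # last column is all 2049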
@@ -956,7 +966,7 @@ def main():
         """Custom WandbCallback to log model predictions during training.
         """

-        def __init__(self, trainer, val_dataset,
+        def __init__(self, trainer, val_dataset, description_tokenizer,  # TODO: add
                      num_samples=8):
             """Initializes the WandbPredictionProgressCallback instance.
@@ -969,6 +979,7 @@ def main():
             """
             super().__init__()
             self.trainer = trainer
+            self.description_tokenizer = description_tokenizer
             self.sample_dataset = val_dataset.select(range(num_samples))

         def on_train_end(self, args, state, control, **kwargs):
@@ -992,6 +1003,7 @@ def main():
     progress_callback = WandbPredictionProgressCallback(
         trainer=trainer,
         val_dataset=vectorized_datasets["eval"],
+        description_tokenizer=description_tokenizer,
         num_samples=8,  # TODO: add to args
     )
stable_speech/configuration_stable_speech.py

@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     Args:
-        vocab_size (`int`, *optional*, defaults to 2048):
+        vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
             represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     def __init__(
         self,
-        vocab_size=2048,
+        vocab_size=2049,  # vocab size = 2048 (encodec vocab size) + 1 (eos token)
         max_position_embeddings=2048,
         num_hidden_layers=24,
         ffn_dim=4096,
@@ -96,9 +96,9 @@ class StableSpeechDecoderConfig(PretrainedConfig):
         initializer_factor=0.02,
         scale_embedding=False,
         num_codebooks=4,
-        pad_token_id=2048,
+        pad_token_id=2049,
         bos_token_id=2048,
-        eos_token_id=None,
+        eos_token_id=2049,
         tie_word_embeddings=False,
         **kwargs,
     ):
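For reference, the special-token layout implied by this commit, read off the config defaults above and the init_dummy_model.py change (a summary for orientation, not code from the repository):

    # ids 0..2047 : EnCodec codebook entries
    # id 2048     : bos_token_id / decoder_start_token_id
    # id 2049     : pad_token_id and eos_token_id (newly introduced, shared value)
    special_tokens = {
        "bos_token_id": 2048,
        "decoder_start_token_id": 2048,
        "pad_token_id": 2049,
        "eos_token_id": 2049,
        "vocab_size": 2049,  # per the inline comment: 2048 encodec codes + 1 eos token
    }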
stable_speech/modeling_stable_speech.py

@@ -731,7 +731,7 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         if prompt_hidden_states is not None:
             inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)

-        # TODO: verify if prompt attention mask is required
+        # TODO: verify if prompt attention mask is required and has to be
         # As it is, the masked ids from the prompt will still count in the positions embeddings
         if prompt_attention_mask is not None and attention_mask is not None:
             attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
@@ -754,6 +754,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         )

         # embed positions
+        # TODO: As it is, the masked ids from the prompt will still count in the positions embeddings
+        # maybe should modify position embeddings
         positions = self.embed_positions(inputs_embeds, past_key_values_length)

         hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
@@ -1064,7 +1066,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         if delay_pattern_mask is None:
             input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
                 input_ids,
-                pad_token_id=self.generation_config.pad_token_id,
+                bos_token_id=self.generation_config.decoder_start_token_id,
+                eos_token_id=self.generation_config.eos_token_id,
                 max_length=self.generation_config.max_length,
             )
@@ -1108,22 +1111,22 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         }

     # Ignore copy
-    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None):
+    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, eos_token_id: int, max_length: int = None):
         """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
         one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
         are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
         seq_len)`:
-        - [P, -1, -1, -1, -1, P, P, P]
-        - [P, P, -1, -1, -1, -1, P, P]
-        - [P, P, P, -1, -1, -1, -1, P]
-        - [P, P, P, P, -1, -1, -1, -1]
-        where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
+        - [B, -1, -1, -1, -1, E, E, E]
+        - [B, B, -1, -1, -1, -1, E, E]
+        - [B, B, B, -1, -1, -1, -1, E]
+        - [B, B, B, B, -1, -1, -1, -1]
+        where B is the BOS token id, E is the EOS token id and -1 indicates that the token is valid for prediction. If we include
         a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
         mask is set to the value in the prompt:
-        - [P, a, b, -1, -1, P, P, P]
-        - [P, P, c, d, -1, -1, P, P]
-        - [P, P, P, e, f, -1, -1, P]
-        - [P, P, P, P, g, h, -1, -1]
+        - [B, a, b, -1, -1, E, E, E]
+        - [B, B, c, d, -1, -1, E, E]
+        - [B, B, B, e, f, -1, -1, E]
+        - [B, B, B, B, g, h, -1, -1]
         where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
         tokens in our prediction.
         """
@@ -1147,14 +1150,16 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # construct a pattern mask that indicates the positions of padding tokens for each codebook
         # first fill the upper triangular part (the EOS padding)
-        delay_pattern = torch.triu(
+        eos_delay_pattern = torch.triu(
             torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
         )
         # then fill the lower triangular part (the BOS padding)
-        delay_pattern = delay_pattern + torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
+        bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))

-        mask = ~delay_pattern.to(input_ids.device)
-        input_ids = mask * input_ids_shifted + ~mask * pad_token_id
+        bos_mask = ~(bos_delay_pattern).to(input_ids.device)
+        eos_mask = ~(eos_delay_pattern).to(input_ids.device)
+        mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
+        input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * eos_token_id

         # find the first position to start generating - this is the first place we have the -1 token
         # and will always be in the first codebook (since it has no codebook offset)
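To make the new BOS/EOS delay pattern concrete, here is a minimal, self-contained sketch of the construction above with the toy sizes from the docstring (variable names mirror the diff, but the snippet is an illustration rather than the repository's code):

    import torch

    num_codebooks, max_length = 4, 8
    bos_token_id, eos_token_id = 2048, 2049

    # shifted input ids filled with -1 ("free to predict"), shape (codebooks, seq_len)
    input_ids_shifted = -torch.ones((num_codebooks, max_length), dtype=torch.long)

    # upper-triangular part marks the EOS padding at the end of each offset codebook row
    eos_delay_pattern = torch.triu(
        torch.ones((num_codebooks, max_length), dtype=torch.bool),
        diagonal=max_length - num_codebooks + 1,
    )
    # lower-triangular part marks the BOS padding at the start of each codebook row
    bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))

    bos_mask = ~bos_delay_pattern
    eos_mask = ~eos_delay_pattern
    mask = ~(bos_delay_pattern + eos_delay_pattern)

    pattern = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * eos_token_id
    print(pattern)
    # row 0: [2048,   -1,   -1,   -1,   -1, 2049, 2049, 2049]
    # row 3: [2048, 2048, 2048, 2048,   -1,   -1,   -1,   -1]

This reproduces the [B, ..., E] rows shown in the docstring: BOS ids fill the lower triangle, EOS ids the upper triangle, and the remaining -1 positions are the ones the decoder is allowed to predict.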
@@ -1334,7 +1339,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
         input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
             input_ids,
-            pad_token_id=generation_config.decoder_start_token_id,
+            bos_token_id=generation_config.decoder_start_token_id,
+            eos_token_id=generation_config.eos_token_id,
             max_length=generation_config.max_length,
         )
@@ -1436,9 +1442,9 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])

-        # revert the pattern delay mask by filtering the pad token id
-        output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
-            batch_size, self.num_codebooks, -1
-        )
+        # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
+        output_ids = output_ids[
+            (model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)
+            & (model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)
+        ].reshape(batch_size, self.decoder.num_codebooks, -1)

         if generation_config.return_dict_in_generate:
@@ -1919,7 +1925,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if prompt_hidden_states is None:
             if prompt_input_ids is not None:
                 prompt_hidden_states = self.embed_prompts(prompt_input_ids)

-        # TODO: do we do something with prompt_attention_mask ? e.g multiply it to prompt_hidden_states?
+        # TODO: verify prompt_attention_mask
         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
             decoder_input_ids = shift_tokens_right(
@@ -1999,7 +2005,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if decoder_delay_pattern_mask is None:
             decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
                 decoder_input_ids,
-                self.generation_config.pad_token_id,
+                bos_token_id=self.generation_config.decoder_start_token_id,
+                eos_token_id=self.generation_config.eos_token_id,
                 max_length=self.generation_config.max_length,
             )
@@ -2428,7 +2435,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
         input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
             input_ids,
-            pad_token_id=generation_config.decoder_start_token_id,
+            bos_token_id=generation_config.decoder_start_token_id,
+            eos_token_id=generation_config.eos_token_id,
             max_length=generation_config.max_length,
         )

         # stash the delay mask so that we don't have to recompute in each forward pass
@@ -2531,8 +2539,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

-        # revert the pattern delay mask by filtering the pad token id
-        output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
+        # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
+        output_ids = output_ids[
+            (model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)
+            & (model_kwargs["decoder_delay_pattern_mask"] != generation_config.eos_token_id)
+        ].reshape(
             batch_size, self.decoder.num_codebooks, -1
         )
@@ -2543,10 +2551,23 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if audio_scales is None:
             audio_scales = [None] * batch_size

-        output_values = self.audio_encoder.decode(
-            output_ids,
-            audio_scales=audio_scales,
-        ).audio_values
+        decode_in_batch = (
+            (output_ids == generation_config.bos_token_id).sum()
+            + (output_ids == generation_config.eos_token_id).sum()
+        ) == 0
+        if decode_in_batch.item():
+            output_values = self.audio_encoder.decode(
+                output_ids,
+                audio_scales=audio_scales,
+            ).audio_values
+        else:
+            output_values = []
+            for sample_id in range(batch_size):
+                sample = output_ids[:, sample_id]
+                sample_mask = (
+                    ((sample == generation_config.bos_token_id) | (sample == generation_config.eos_token_id)).sum(dim=(0, 1)) == 0
+                )
+                sample = sample[:, :, sample_mask]
+                sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
+                output_values.append(sample.transpose(0, 2))
+            # TODO: we should keep track of output length as well. Not really straightforward tbh
+            output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1, 2).squeeze(-1)

         if generation_config.return_dict_in_generate:
             outputs.sequences = output_values
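A rough illustration of the final padding step in the new per-sample branch, with invented waveform lengths (the real layout of the decoded audio comes from the audio encoder): each clip is decoded on its own, so the resulting waveforms differ in length and are batched back together with pad_sequence.

    import torch

    # two mono waveforms of different lengths, shaped (channels=1, time) after decoding
    wav_a = torch.randn(1, 12000)
    wav_b = torch.randn(1, 8000)

    # move time to the front so pad_sequence pads along it, then restore the layout
    stacked = torch.nn.utils.rnn.pad_sequence(
        [wav_a.transpose(0, 1), wav_b.transpose(0, 1)],  # each (time, 1)
        batch_first=True,
        padding_value=0,
    )                                          # (batch, max_time, 1)
    output_values = stacked.transpose(1, 2)    # (batch, 1, max_time), shorter clip zero-padded
    print(output_values.shape)                 # torch.Size([2, 1, 12000])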