Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
parler-tts
Commits
0f6d59d4
Commit
0f6d59d4
authored
Feb 26, 2024
by
Yoach Lacombe
Browse files
latest changes
parent
11fcc066
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
739 additions
and
109 deletions
+739
-109
example_configs/librispeech_tts_r.json
example_configs/librispeech_tts_r.json
+18
-12
example_configs/librispeech_tts_r_dummy.json
example_configs/librispeech_tts_r_dummy.json
+74
-0
init_dummy_model.py
init_dummy_model.py
+10
-10
run_stable_speech_training.py
run_stable_speech_training.py
+573
-52
stable_speech/configuration_stable_speech.py
stable_speech/configuration_stable_speech.py
+3
-3
stable_speech/modeling_stable_speech.py
stable_speech/modeling_stable_speech.py
+61
-32
No files found.
example_configs/librispeech_tts_r.json
View file @
0f6d59d4
...
@@ -24,11 +24,11 @@
...
@@ -24,11 +24,11 @@
"description_column_name"
:
"text_description"
,
"description_column_name"
:
"text_description"
,
"prompt_column_name"
:
"text"
,
"prompt_column_name"
:
"text"
,
"max_train_samples"
:
100
0
,
"max_train_samples"
:
2
0
,
"max_eval_samples"
:
20
0
,
"max_eval_samples"
:
1
0
,
"max_duration_in_seconds"
:
2
0
,
"max_duration_in_seconds"
:
3
0
,
"min_duration_in_seconds"
:
1.0
,
"min_duration_in_seconds"
:
1.0
,
"add_audio_samples_to_wandb"
:
true
,
"add_audio_samples_to_wandb"
:
true
,
...
@@ -36,30 +36,36 @@
...
@@ -36,30 +36,36 @@
"preprocessing_num_workers"
:
1
,
"preprocessing_num_workers"
:
1
,
"pad_token_id"
:
20
49
,
"pad_token_id"
:
20
50
,
"decoder_start_token_id"
:
2048
,
"decoder_start_token_id"
:
2048
,
"do_train"
:
true
,
"do_train"
:
true
,
"num_train_epochs"
:
1
,
"num_train_epochs"
:
1
20
,
"gradient_accumulation_steps"
:
1
,
"gradient_accumulation_steps"
:
1
,
"gradient_checkpointing"
:
tru
e
,
"gradient_checkpointing"
:
fals
e
,
"per_device_train_batch_size"
:
8
,
"per_device_train_batch_size"
:
2
,
"learning_rate"
:
1e-
6
,
"learning_rate"
:
1e-
3
,
"adam_beta1"
:
0.9
,
"adam_beta1"
:
0.9
,
"adam_beta2"
:
0.9
5
,
"adam_beta2"
:
0.9
99
,
"weight_decay"
:
0.1
,
"weight_decay"
:
0.1
,
"logging_steps"
:
25
,
"lr_scheduler_type"
:
"cosine"
,
"warmup_ratio"
:
0.1
,
"logging_steps"
:
1
,
"freeze_text_encoder"
:
true
,
"do_eval"
:
true
,
"do_eval"
:
true
,
"predict_with_generate"
:
true
,
"predict_with_generate"
:
true
,
"include_inputs_for_metrics"
:
true
,
"include_inputs_for_metrics"
:
true
,
"evaluation_strategy"
:
"epoch"
,
"evaluation_strategy"
:
"steps"
,
"eval_steps"
:
600
,
"per_device_eval_batch_size"
:
8
,
"per_device_eval_batch_size"
:
8
,
"generation_max_length"
:
400
,
"generation_max_length"
:
400
,
"fp16"
:
tru
e
,
"fp16"
:
fals
e
,
"seed"
:
456
,
"seed"
:
456
,
"dataloader_num_workers"
:
8
"dataloader_num_workers"
:
8
...
...
example_configs/librispeech_tts_r_dummy.json
0 → 100644
View file @
0f6d59d4
{
"model_name_or_path"
:
"/home/yoach/dataspeech/artefacts/tiny-model/"
,
"feature_extractor_name"
:
"facebook/encodec_24khz"
,
"description_tokenizer_name"
:
"t5-base"
,
"prompt_tokenizer_name"
:
"t5-base"
,
"push_to_hub"
:
false
,
"hub_model_id"
:
"stable-speech-mini"
,
"report_to"
:
[
"wandb"
],
"overwrite_output_dir"
:
true
,
"output_dir"
:
"/home/yoach/dataspeech/artefacts/training/"
,
"train_dataset_name"
:
"blabble-io/libritts_r"
,
"train_metadata_dataset_name"
:
"stable-speech/libritts-r-tags-and-text-generated"
,
"train_dataset_config_name"
:
"clean"
,
"train_split_name"
:
"train.clean.360"
,
"eval_dataset_name"
:
"blabble-io/libritts_r"
,
"eval_metadata_dataset_name"
:
"stable-speech/libritts-r-tags-and-text-generated"
,
"eval_dataset_config_name"
:
"clean"
,
"eval_split_name"
:
"train.clean.360"
,
"target_audio_column_name"
:
"audio"
,
"description_column_name"
:
"text_description"
,
"prompt_column_name"
:
"text"
,
"max_train_samples"
:
12
,
"max_eval_samples"
:
12
,
"max_duration_in_seconds"
:
30
,
"min_duration_in_seconds"
:
1.0
,
"add_audio_samples_to_wandb"
:
true
,
"id_column_name"
:
"id"
,
"preprocessing_num_workers"
:
1
,
"pad_token_id"
:
2050
,
"decoder_start_token_id"
:
2048
,
"do_train"
:
true
,
"num_train_epochs"
:
20
,
"gradient_accumulation_steps"
:
1
,
"gradient_checkpointing"
:
false
,
"per_device_train_batch_size"
:
3
,
"learning_rate"
:
1e-3
,
"adam_beta1"
:
0.9
,
"adam_beta2"
:
0.999
,
"weight_decay"
:
0.1
,
"lr_scheduler_type"
:
"cosine"
,
"warmup_ratio"
:
0.1
,
"freeze_text_encoder"
:
true
,
"do_eval"
:
true
,
"predict_with_generate"
:
true
,
"include_inputs_for_metrics"
:
true
,
"evaluation_strategy"
:
"steps"
,
"eval_steps"
:
10
,
"per_device_eval_batch_size"
:
3
,
"generation_max_length"
:
400
,
"do_sample"
:
true
,
"logging_steps"
:
15
,
"dtype"
:
"float32"
,
"seed"
:
456
,
"dataloader_num_workers"
:
8
}
init_dummy_model.py
View file @
0f6d59d4
...
@@ -4,16 +4,16 @@ from transformers import AutoConfig
...
@@ -4,16 +4,16 @@ from transformers import AutoConfig
decoder_config
=
StableSpeechDecoderConfig
(
decoder_config
=
StableSpeechDecoderConfig
(
max_position_embeddings
=
2048
,
max_position_embeddings
=
2048
,
num_hidden_layers
=
2
,
num_hidden_layers
=
4
,
ffn_dim
=
256
,
ffn_dim
=
512
,
num_attention_heads
=
4
,
num_attention_heads
=
8
,
layerdrop
=
0.0
,
layerdrop
=
0.0
,
use_cache
=
True
,
use_cache
=
True
,
activation_function
=
"gelu"
,
activation_function
=
"gelu"
,
hidden_size
=
256
,
hidden_size
=
512
,
dropout
=
0.
1
,
dropout
=
0.
0
,
attention_dropout
=
0.
1
,
attention_dropout
=
0.
0
,
activation_dropout
=
0.
1
,
activation_dropout
=
0.
0
,
)
)
# TODO: ?? how to make it stop ?
# TODO: ?? how to make it stop ?
...
@@ -35,12 +35,12 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
...
@@ -35,12 +35,12 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
# set the appropriate bos/pad token ids
# set the appropriate bos/pad token ids
model
.
generation_config
.
decoder_start_token_id
=
2048
model
.
generation_config
.
decoder_start_token_id
=
2048
model
.
generation_config
.
pad_token_id
=
20
49
model
.
generation_config
.
pad_token_id
=
20
50
model
.
generation_config
.
eos_token_id
=
2049
model
.
generation_config
.
eos_token_id
=
2049
# set other default generation config params
# set other default generation config params
model
.
generation_config
.
max_length
=
int
(
30
*
model
.
audio_encoder
.
config
.
frame_rate
)
model
.
generation_config
.
max_length
=
int
(
30
*
model
.
audio_encoder
.
config
.
frame_rate
)
model
.
generation_config
.
do_sample
=
True
model
.
generation_config
.
do_sample
=
False
#
True
model
.
generation_config
.
guidance_scale
=
3.0
model
.
generation_config
.
guidance_scale
=
1
#
3.0
model
.
save_pretrained
(
"/home/yoach/dataspeech/artefacts/tiny-model/"
)
model
.
save_pretrained
(
"/home/yoach/dataspeech/artefacts/tiny-model/"
)
\ No newline at end of file
run_stable_speech_training.py
View file @
0f6d59d4
This diff is collapsed.
Click to expand it.
stable_speech/configuration_stable_speech.py
View file @
0f6d59d4
...
@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
...
@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
Args:
Args:
vocab_size (`int`, *optional*, defaults to 20
49
):
vocab_size (`int`, *optional*, defaults to 20
50
):
Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
hidden_size (`int`, *optional*, defaults to 1024):
hidden_size (`int`, *optional*, defaults to 1024):
...
@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
...
@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
def
__init__
(
def
__init__
(
self
,
self
,
vocab_size
=
20
49
,
# vocab size = 2048 (encodec vocab size) +
1
(
e
os
token
)
vocab_size
=
20
50
,
# vocab size = 2048 (encodec vocab size) +
2
(
b
os
, eos
)
max_position_embeddings
=
2048
,
max_position_embeddings
=
2048
,
num_hidden_layers
=
24
,
num_hidden_layers
=
24
,
ffn_dim
=
4096
,
ffn_dim
=
4096
,
...
@@ -96,7 +96,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
...
@@ -96,7 +96,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
initializer_factor
=
0.02
,
initializer_factor
=
0.02
,
scale_embedding
=
False
,
scale_embedding
=
False
,
num_codebooks
=
4
,
num_codebooks
=
4
,
pad_token_id
=
20
49
,
pad_token_id
=
20
50
,
bos_token_id
=
2048
,
bos_token_id
=
2048
,
eos_token_id
=
2049
,
eos_token_id
=
2049
,
tie_word_embeddings
=
False
,
tie_word_embeddings
=
False
,
...
...
stable_speech/modeling_stable_speech.py
View file @
0f6d59d4
...
@@ -659,7 +659,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
...
@@ -659,7 +659,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
self
.
num_codebooks
=
config
.
num_codebooks
self
.
num_codebooks
=
config
.
num_codebooks
self
.
embed_scale
=
math
.
sqrt
(
config
.
hidden_size
)
if
config
.
scale_embedding
else
1.0
self
.
embed_scale
=
math
.
sqrt
(
config
.
hidden_size
)
if
config
.
scale_embedding
else
1.0
embed_dim
=
config
.
vocab_size
+
1
# TODO: not right dim
embed_dim
=
config
.
vocab_size
+
1
# + 1 for pad token id
self
.
embed_tokens
=
nn
.
ModuleList
(
self
.
embed_tokens
=
nn
.
ModuleList
(
[
nn
.
Embedding
(
embed_dim
,
config
.
hidden_size
)
for
_
in
range
(
config
.
num_codebooks
)]
[
nn
.
Embedding
(
embed_dim
,
config
.
hidden_size
)
for
_
in
range
(
config
.
num_codebooks
)]
)
)
...
@@ -981,6 +982,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
...
@@ -981,6 +982,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
past_key_values
:
Optional
[
Tuple
[
Tuple
[
torch
.
FloatTensor
]]]
=
None
,
past_key_values
:
Optional
[
Tuple
[
Tuple
[
torch
.
FloatTensor
]]]
=
None
,
inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
labels
:
Optional
[
torch
.
LongTensor
]
=
None
,
labels
:
Optional
[
torch
.
LongTensor
]
=
None
,
label_delay_pattern_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
...
@@ -991,7 +993,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
...
@@ -991,7 +993,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
Returns:
# TODO: delay_pattern_mask
Returns:
"""
"""
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
return_dict
=
return_dict
if
return_dict
is
not
None
else
self
.
config
.
use_return_dict
...
@@ -1019,17 +1022,36 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
...
@@ -1019,17 +1022,36 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
loss
=
None
loss
=
None
if
labels
is
not
None
:
if
labels
is
not
None
:
loss
=
torch
.
zeros
([],
device
=
self
.
device
)
# since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
# since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
logits
=
lm_logits
[:,:,
-
labels
.
shape
[
1
]:]
logits
=
lm_logits
[:,:,
-
labels
.
shape
[
1
]:]
loss_fct
=
CrossEntropyLoss
()
loss_fct
=
CrossEntropyLoss
()
loss
=
torch
.
zeros
([],
device
=
self
.
device
)
loss
=
torch
.
zeros
([],
device
=
self
.
device
)
# per codebook cross-entropy
# -100 labels are ignored
# (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
# (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
labels
=
labels
.
masked_fill
(
labels
==
self
.
config
.
bos_token_id
,
-
100
)
labels
=
labels
.
masked_fill
(
labels
==
self
.
config
.
pad_token_id
,
-
100
)
labels
=
labels
.
masked_fill
(
labels
==
self
.
config
.
pad_token_id
,
-
100
)
loss
=
loss_fct
(
logits
.
transpose
(
1
,
3
),
labels
)
# loss = loss_fct(logits.transpose(1,3), labels)
# -100 labels are ignored
# TODO: probably no need for label_delay_pattern_mask
# mask = label_delay_pattern_mask[:, :labels.shape[1]]
# mask = (labels != self.generation_config.bos_token_id)&(labels != -100)
mask
=
(
labels
!=
-
100
)
# per codebook cross-entropy
for
codebook
in
range
(
self
.
config
.
num_codebooks
):
codebook_logits
=
logits
[:,
codebook
].
contiguous
().
view
(
-
1
,
logits
.
shape
[
-
1
])
codebook_mask
=
mask
[...,
codebook
].
contiguous
().
view
(
-
1
)
codebook_labels
=
labels
[...,
codebook
].
contiguous
().
view
(
-
1
)
codebook_loss
=
loss_fct
(
codebook_logits
[
codebook_mask
],
codebook_labels
[
codebook_mask
])
loss
+=
codebook_loss
loss
=
loss
/
self
.
config
.
num_codebooks
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
lm_logits
=
lm_logits
.
reshape
(
-
1
,
*
lm_logits
.
shape
[
2
:])
lm_logits
=
lm_logits
.
reshape
(
-
1
,
*
lm_logits
.
shape
[
2
:])
...
@@ -1066,8 +1088,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
...
@@ -1066,8 +1088,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
if
delay_pattern_mask
is
None
:
if
delay_pattern_mask
is
None
:
input_ids
,
delay_pattern_mask
=
self
.
build_delay_pattern_mask
(
input_ids
,
delay_pattern_mask
=
self
.
build_delay_pattern_mask
(
input_ids
,
input_ids
,
bos_token_id
=
self
.
generation_config
.
decoder_start
_token_id
,
bos_token_id
=
self
.
generation_config
.
bos
_token_id
,
eos
_token_id
=
self
.
generation_config
.
eos
_token_id
,
pad
_token_id
=
self
.
generation_config
.
pad
_token_id
,
max_length
=
self
.
generation_config
.
max_length
,
max_length
=
self
.
generation_config
.
max_length
,
)
)
...
@@ -1111,21 +1133,21 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
...
@@ -1111,21 +1133,21 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
}
}
# Ignore copy
# Ignore copy
def
build_delay_pattern_mask
(
self
,
input_ids
:
torch
.
LongTensor
,
bos_token_id
:
int
,
eos
_token_id
:
int
,
max_length
:
int
=
None
):
def
build_delay_pattern_mask
(
self
,
input_ids
:
torch
.
LongTensor
,
bos_token_id
:
int
,
pad
_token_id
:
int
,
max_length
:
int
=
None
):
"""Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
"""Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
seq_len)`:
seq_len)`:
- [B, -1, -1, -1, -1,
E
,
E
,
E
]
- [B, -1, -1, -1, -1,
P
,
P
,
P
]
- [B, B, -1, -1, -1, -1,
E
,
E
]
- [B, B, -1, -1, -1, -1,
P
,
P
]
- [B, B, B, -1, -1, -1, -1,
E
]
- [B, B, B, -1, -1, -1, -1,
P
]
- [B, B, B, B, -1, -1, -1, -1]
- [B, B, B, B, -1, -1, -1, -1]
where
B
is the
BOS token id, E is the EOS
token id and -1 indicates that the token is valid for prediction. If we include
where
P
is the
special padding
token id and -1 indicates that the token is valid for prediction. If we include
a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
mask is set to the value in the prompt:
mask is set to the value in the prompt:
- [B, a, b, -1, -1,
E
,
E
,
E
]
- [B, a, b, -1, -1,
P
,
P
,
P
]
- [B, B, c, d, -1, -1,
E
,
E
]
- [B, B, c, d, -1, -1,
P
,
P
]
- [B, B, B, e, f, -1, -1,
E
]
- [B, B, B, e, f, -1, -1,
P
]
- [B, B, B, B, g, h, -1, -1]
- [B, B, B, B, g, h, -1, -1]
where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
tokens in our prediction.
tokens in our prediction.
...
@@ -1159,7 +1181,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
...
@@ -1159,7 +1181,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
bos_mask
=
~
(
bos_delay_pattern
).
to
(
input_ids
.
device
)
bos_mask
=
~
(
bos_delay_pattern
).
to
(
input_ids
.
device
)
eos_mask
=
~
(
eos_delay_pattern
).
to
(
input_ids
.
device
)
eos_mask
=
~
(
eos_delay_pattern
).
to
(
input_ids
.
device
)
mask
=
~
(
bos_delay_pattern
+
eos_delay_pattern
).
to
(
input_ids
.
device
)
mask
=
~
(
bos_delay_pattern
+
eos_delay_pattern
).
to
(
input_ids
.
device
)
input_ids
=
mask
*
input_ids_shifted
+
~
bos_mask
*
bos_token_id
+
~
eos_mask
*
eos
_token_id
input_ids
=
mask
*
input_ids_shifted
+
~
bos_mask
*
bos_token_id
+
~
eos_mask
*
pad
_token_id
# find the first position to start generating - this is the first place we have the -1 token
# find the first position to start generating - this is the first place we have the -1 token
# and will always be in the first codebook (since it has no codebook offset)
# and will always be in the first codebook (since it has no codebook offset)
...
@@ -1339,8 +1361,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
...
@@ -1339,8 +1361,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
# Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
input_ids
,
delay_pattern_mask
=
self
.
build_delay_pattern_mask
(
input_ids
,
delay_pattern_mask
=
self
.
build_delay_pattern_mask
(
input_ids
,
input_ids
,
bos_token_id
=
generation_config
.
decoder_start
_token_id
,
bos_token_id
=
generation_config
.
bos
_token_id
,
eos
_token_id
=
generation_config
.
eos
_token_id
,
pad
_token_id
=
generation_config
.
pad
_token_id
,
max_length
=
generation_config
.
max_length
,
max_length
=
generation_config
.
max_length
,
)
)
...
@@ -1846,6 +1868,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -1846,6 +1868,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
prompt_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
# TODO: add to docstrings
prompt_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
# TODO: add to docstrings
prompt_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
# TODO: add to docstrings
prompt_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
# TODO: add to docstrings
labels
:
Optional
[
torch
.
LongTensor
]
=
None
,
labels
:
Optional
[
torch
.
LongTensor
]
=
None
,
label_delay_pattern_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
...
@@ -1928,9 +1951,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -1928,9 +1951,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
# TODO: verify prompt_attention_mask
# TODO: verify prompt_attention_mask
if
(
labels
is
not
None
)
and
(
decoder_input_ids
is
None
and
decoder_inputs_embeds
is
None
):
if
(
labels
is
not
None
)
and
(
decoder_input_ids
is
None
and
decoder_inputs_embeds
is
None
):
# TODO: verify it does what's expected
decoder_input_ids
=
shift_tokens_right
(
decoder_input_ids
=
shift_tokens_right
(
labels
.
transpose
(
1
,
2
)
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
labels
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
)
)
.
transpose
(
1
,
2
)
elif
decoder_input_ids
is
None
and
decoder_inputs_embeds
is
None
:
elif
decoder_input_ids
is
None
and
decoder_inputs_embeds
is
None
:
audio_encoder_outputs
=
self
.
audio_encoder
(
audio_encoder_outputs
=
self
.
audio_encoder
(
...
@@ -1967,6 +1991,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -1967,6 +1991,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
past_key_values
=
past_key_values
,
past_key_values
=
past_key_values
,
return_dict
=
return_dict
,
return_dict
=
return_dict
,
labels
=
labels
,
labels
=
labels
,
label_delay_pattern_mask
=
label_delay_pattern_mask
,
**
kwargs_decoder
,
**
kwargs_decoder
,
)
)
...
@@ -2005,8 +2030,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -2005,8 +2030,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if
decoder_delay_pattern_mask
is
None
:
if
decoder_delay_pattern_mask
is
None
:
decoder_input_ids
,
decoder_delay_pattern_mask
=
self
.
decoder
.
build_delay_pattern_mask
(
decoder_input_ids
,
decoder_delay_pattern_mask
=
self
.
decoder
.
build_delay_pattern_mask
(
decoder_input_ids
,
decoder_input_ids
,
bos_token_id
=
self
.
generation_config
.
decoder_start
_token_id
,
bos_token_id
=
self
.
generation_config
.
bos
_token_id
,
eos
_token_id
=
self
.
generation_config
.
eos
_token_id
,
pad
_token_id
=
self
.
generation_config
.
pad
_token_id
,
max_length
=
self
.
generation_config
.
max_length
,
max_length
=
self
.
generation_config
.
max_length
,
)
)
...
@@ -2197,7 +2222,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -2197,7 +2222,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
return
model_kwargs
return
model_kwargs
def
prepare_decoder_input_ids_from_labels
(
self
,
labels
:
torch
.
Tensor
):
def
prepare_decoder_input_ids_from_labels
(
self
,
labels
:
torch
.
Tensor
):
return
shift_tokens_right
(
labels
.
transpose
(
1
,
2
)
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
)
return
shift_tokens_right
(
labels
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
)
.
transpose
(
1
,
2
)
def
resize_token_embeddings
(
self
,
*
args
,
**
kwargs
):
def
resize_token_embeddings
(
self
,
*
args
,
**
kwargs
):
# TODO: now it's possible with prompt_embeddings
# TODO: now it's possible with prompt_embeddings
...
@@ -2435,8 +2460,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -2435,8 +2460,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
# build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
input_ids
,
decoder_delay_pattern_mask
=
self
.
decoder
.
build_delay_pattern_mask
(
input_ids
,
decoder_delay_pattern_mask
=
self
.
decoder
.
build_delay_pattern_mask
(
input_ids
,
input_ids
,
bos_token_id
=
generation_config
.
decoder_start
_token_id
,
bos_token_id
=
generation_config
.
bos
_token_id
,
eos
_token_id
=
generation_config
.
eos
_token_id
,
pad
_token_id
=
generation_config
.
pad
_token_id
,
max_length
=
generation_config
.
max_length
,
max_length
=
generation_config
.
max_length
,
)
)
# stash the delay mask so that we don't have to recompute in each forward pass
# stash the delay mask so that we don't have to recompute in each forward pass
...
@@ -2540,7 +2565,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -2540,7 +2565,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
output_ids
=
self
.
decoder
.
apply_delay_pattern_mask
(
output_ids
,
model_kwargs
[
"decoder_delay_pattern_mask"
])
output_ids
=
self
.
decoder
.
apply_delay_pattern_mask
(
output_ids
,
model_kwargs
[
"decoder_delay_pattern_mask"
])
# revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
# revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
output_ids
=
output_ids
[(
model_kwargs
[
"decoder_delay_pattern_mask"
]
!=
generation_config
.
bos_token_id
)
&
(
model_kwargs
[
"decoder_delay_pattern_mask"
]
!=
generation_config
.
eos_token_id
)].
reshape
(
# TODO: probably won't work...
output_ids
=
output_ids
[(
model_kwargs
[
"decoder_delay_pattern_mask"
]
!=
generation_config
.
bos_token_id
)
&
(
model_kwargs
[
"decoder_delay_pattern_mask"
]
!=
generation_config
.
pad_token_id
)].
reshape
(
batch_size
,
self
.
decoder
.
num_codebooks
,
-
1
batch_size
,
self
.
decoder
.
num_codebooks
,
-
1
)
)
...
@@ -2556,17 +2582,20 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
...
@@ -2556,17 +2582,20 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
output_values
=
self
.
audio_encoder
.
decode
(
output_values
=
self
.
audio_encoder
.
decode
(
output_ids
,
output_ids
,
audio_scales
=
audio_scales
,
audio_scales
=
audio_scales
,
).
audio_values
).
audio_values
.
squeeze
(
1
)
else
:
else
:
output_values
=
[]
output_values
=
[]
for
sample_id
in
range
(
batch_size
):
for
sample_id
in
range
(
batch_size
):
sample
=
output_ids
[:,
sample_id
]
sample
=
output_ids
[:,
sample_id
]
sample_mask
=
(((
sample
==
generation_config
.
bos_token_id
)
|
(
sample
==
generation_config
.
eos_token_id
)).
sum
(
dim
=
(
0
,
1
))
==
0
)
sample_mask
=
(((
sample
==
generation_config
.
bos_token_id
)
|
(
sample
==
generation_config
.
eos_token_id
)
|
(
sample
==
generation_config
.
pad_token_id
)).
sum
(
dim
=
(
0
,
1
))
==
0
)
sample
=
sample
[:,
:,
sample_mask
]
if
sample_mask
.
sum
()
>
0
:
sample
=
self
.
audio_encoder
.
decode
(
sample
[
None
,
...],
[
audio_scales
[
sample_id
]]).
audio_values
sample
=
sample
[:,
:,
sample_mask
]
output_values
.
append
(
sample
.
transpose
(
0
,
2
))
sample
=
self
.
audio_encoder
.
decode
(
sample
[
None
,
...],
[
audio_scales
[
sample_id
]]).
audio_values
output_values
.
append
(
sample
.
transpose
(
0
,
2
))
else
:
output_values
.
append
(
torch
.
zeros
((
1
,
1
,
1
)).
to
(
self
.
device
))
# TODO: we should keep track of output length as well. Not really straightfoward tbh
# TODO: we should keep track of output length as well. Not really straightfoward tbh
output_values
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
output_values
,
batch_first
=
True
,
padding_value
=
0
).
transpose
(
1
,
2
).
squeeze
(
-
1
)
output_values
=
torch
.
nn
.
utils
.
rnn
.
pad_sequence
(
output_values
,
batch_first
=
True
,
padding_value
=
0
).
transpose
(
1
,
2
).
squeeze
(
-
1
)
.
squeeze
(
1
)
if
generation_config
.
return_dict_in_generate
:
if
generation_config
.
return_dict_in_generate
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment