chenpangpang / parler-tts

Commit 11fcc066, authored Feb 23, 2024 by Yoach Lacombe
Parent: 997bf5e6

    working training + generation

Showing 5 changed files with 88 additions and 54 deletions:

    example_configs/librispeech_tts_r.json        +4   -4
    init_dummy_model.py                           +2   -1
    run_stable_speech_training.py                 +28  -16
    stable_speech/configuration_stable_speech.py  +4   -4
    stable_speech/modeling_stable_speech.py       +50  -29
example_configs/librispeech_tts_r.json  +4 -4

@@ -36,14 +36,14 @@
     "preprocessing_num_workers": 1,
-    "pad_token_id": 2048,
+    "pad_token_id": 2049,
     "decoder_start_token_id": 2048,
     "do_train": true,
-    "num_train_epochs": 20,
+    "num_train_epochs": 1,
     "gradient_accumulation_steps": 1,
     "gradient_checkpointing": true,
-    "per_device_train_batch_size": 16,
+    "per_device_train_batch_size": 8,
     "learning_rate": 1e-6,
     "adam_beta1": 0.9,
     "adam_beta2": 0.95,
@@ -56,7 +56,7 @@
     "predict_with_generate": true,
     "include_inputs_for_metrics": true,
     "evaluation_strategy": "epoch",
-    "per_device_eval_batch_size": 16,
+    "per_device_eval_batch_size": 8,
     "generation_max_length": 400,
     "fp16": true,
init_dummy_model.py  +2 -1

@@ -35,7 +35,8 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
 # set the appropriate bos/pad token ids
 model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2048
+model.generation_config.pad_token_id = 2049
+model.generation_config.eos_token_id = 2049

 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
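Note: in the last context line above, max_length is measured in codec frames, so the factor of 30 is a duration in seconds. A minimal sketch of the arithmetic, assuming the audio encoder is EnCodec's 24 kHz checkpoint with its 75 Hz frame rate (the commit itself does not pin down the codec):

    # 30 seconds of audio at an assumed 75 Hz codec frame rate
    frame_rate = 75  # frames per second; an assumption, not set by this commit
    max_length = int(30 * frame_rate)
    print(max_length)  # 2250 decoder steps per codebook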
run_stable_speech_training.py  +28 -16
@@ -77,12 +77,6 @@ def list_field(default=None, metadata=None):
 class StableSpeechTrainer(Seq2SeqTrainer):
     def _pad_tensors_to_max_len(self, tensor, max_length):
-        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
-            # If PAD token is not defined at least EOS token has to be defined
-            pad_token_id = (
-                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
-            )
-        else:
-            if self.model.config.pad_token_id is not None:
-                pad_token_id = self.model.config.pad_token_id
-            else:
+        if self.model.config.pad_token_id is not None:
+            pad_token_id = self.model.config.pad_token_id
+        else:
@@ -387,6 +381,7 @@ class DataCollatorStableSpeechWithPadding:
         prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
         prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
+        # TODO: check it's been padded on the left
         batch["prompt_input_ids"] = prompt_input_ids["input_ids"]
         if "attention_mask" in prompt_input_ids:
             batch["prompt_attention_mask"] = prompt_input_ids["attention_mask"]
@@ -676,6 +671,7 @@ def main():
     )
     # update pad token id and decoder_start_token_id
+    # TODO: verify if this makes sense, maybe should do it for model.decoder
     config.update({
         "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else model.config.pad_token_id,
         "decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else model.config.decoder_start_token_id,
@@ -700,6 +696,7 @@ def main():
         token=data_args.token,
         trust_remote_code=data_args.trust_remote_code,
         use_fast=model_args.use_fast_tokenizer,
+        padding_side="left",  # the prompt has to be padded on the left because it is prepended to the codebook hidden states
     )

     # load description tokenizer
@@ -740,6 +737,10 @@ def main():
     description_column_name = data_args.description_column_name
     prompt_column_name = data_args.prompt_column_name
     feature_extractor_input_name = feature_extractor.model_input_names[0]
+    audio_encoder_eos_token_id = config.decoder.pad_token_id
+    audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
+    max_length = model.generation_config.max_length
+    num_codebooks = model.decoder.config.num_codebooks

     # resample target audio
     raw_datasets = raw_datasets.cast_column(
@@ -794,7 +795,6 @@ def main():
     # no need to prepare audio_decoder because used for inference without mixed precision
     # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
-    # TODO: load another model
     audio_decoder = model.audio_encoder
     encoder_data_collator = DataCollatorEncodecWithPadding(feature_extractor, feature_extractor_input_name)
@@ -832,18 +832,28 @@ def main():
             all_ratios.extend(generate_labels["ratio"].cpu())
             all_lens.extend(generate_labels["len_audio"].cpu())

+        # (1, codebooks, seq_len) where seq_len=1
+        eos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_eos_token_id
+
         def postprocess_dataset(sample, idx):
-            # (1, seq_len, codebooks, bsz)
+            # (1, codebooks, seq_len)
             labels = all_generated_labels[idx].transpose(0, 1).unsqueeze(0)
+            len_ = int(all_ratios[idx] * all_lens[idx])
+            labels = labels[:, :, :len_]
+
+            # add eos token column
+            labels = torch.cat([labels, eos_labels.to(labels.device).to(labels.dtype)], dim=-1)
             labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels,
-                model.generation_config.decoder_start_token_id,
-                model.generation_config.max_length + model.decoder.config.num_codebooks)
+                audio_encoder_bos_token_id,
+                audio_encoder_eos_token_id,
+                max_length + num_codebooks)
             labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
-            len_ = int(all_ratios[idx] * all_lens[idx])

             # the first timestamp is associated to a row full of BOS, let's get rid of it
-            sample["labels"] = labels[:, 1:len_]
+            sample["labels"] = labels[:, 1:]
             return sample

         # TODO: done multiple times, how to deal with it.
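Note: a toy run of the EOS-append step in postprocess_dataset, with made-up label values; the 4-codebook count and the shared EOS/pad id 2049 mirror this commit's new defaults:

    import torch

    num_codebooks, seq_len, eos_id = 4, 6, 2049
    labels = torch.randint(0, 2048, (1, num_codebooks, seq_len))  # fake audio codes

    # one EOS column of shape (1, codebooks, 1), as in the hunk above
    eos_labels = torch.ones((1, num_codebooks, 1), dtype=labels.dtype) * eos_id
    labels = torch.cat([labels, eos_labels], dim=-1)
    print(labels.shape)  # torch.Size([1, 4, 7]): the codes plus one final EOS column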
@@ -956,7 +966,7 @@ def main():
     """Custom WandbCallback to log model predictions during training.
     """

-    def __init__(self, trainer, val_dataset,
+    def __init__(self, trainer, val_dataset, description_tokenizer,  # TODO: add
                  num_samples=8):
         """Initializes the WandbPredictionProgressCallback instance.
@@ -969,6 +979,7 @@ def main():
         """
         super().__init__()
         self.trainer = trainer
+        self.description_tokenizer = description_tokenizer
         self.sample_dataset = val_dataset.select(range(num_samples))

     def on_train_end(self, args, state, control, **kwargs):
@@ -992,6 +1003,7 @@ def main():
     progress_callback = WandbPredictionProgressCallback(
         trainer=trainer,
         val_dataset=vectorized_datasets["eval"],
+        description_tokenizer=description_tokenizer,
         num_samples=8,  # TODO: add to args
     )
stable_speech/configuration_stable_speech.py  +4 -4
@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     Args:
-        vocab_size (`int`, *optional*, defaults to 2048):
+        vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
             represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     def __init__(
         self,
-        vocab_size=2048,
+        vocab_size=2049,  # vocab size = 2048 (encodec vocab size) + 1 (eos token)
         max_position_embeddings=2048,
         num_hidden_layers=24,
         ffn_dim=4096,
@@ -96,9 +96,9 @@ class StableSpeechDecoderConfig(PretrainedConfig):
         initializer_factor=0.02,
         scale_embedding=False,
         num_codebooks=4,
-        pad_token_id=2048,
+        pad_token_id=2049,
         bos_token_id=2048,
-        eos_token_id=None,
+        eos_token_id=2049,
         tie_word_embeddings=False,
         **kwargs,
     ):
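Note: read together, the new defaults pin down the decoder's token-id layout. A small restatement (my reading of the config, not code from the commit):

    # EnCodec's 2048 codebook entries occupy ids 0..2047; the BOS/decoder-start
    # token sits at 2048; the newly introduced EOS token doubles as pad at 2049.
    NUM_ENCODEC_CODES = 2048
    BOS_TOKEN_ID = 2048
    EOS_TOKEN_ID = PAD_TOKEN_ID = 2049
    VOCAB_SIZE = NUM_ENCODEC_CODES + 1  # 2049, per the inline comment above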
stable_speech/modeling_stable_speech.py  +50 -29
@@ -731,7 +731,7 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         if prompt_hidden_states is not None:
             inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)

-        # TODO: verify if prompt attention mask is required
+        # TODO: verify if prompt attention mask is required and has to be
         # As it is, the masked ids from the prompt will still count in the positions embeddings
         if prompt_attention_mask is not None and attention_mask is not None:
             attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
@@ -754,6 +754,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         )

         # embed positions
+        # TODO: As it is, the masked ids from the prompt will still count in the positions embeddings
+        # maybe should modify position embeddings
         positions = self.embed_positions(inputs_embeds, past_key_values_length)

         hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
@@ -1064,7 +1066,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         if delay_pattern_mask is None:
             input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
                 input_ids,
-                pad_token_id=self.generation_config.pad_token_id,
+                bos_token_id=self.generation_config.decoder_start_token_id,
+                eos_token_id=self.generation_config.eos_token_id,
                 max_length=self.generation_config.max_length,
             )
@@ -1108,22 +1111,22 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         }

     # Ignore copy
-    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, pad_token_id: int, max_length: int = None):
+    def build_delay_pattern_mask(self, input_ids: torch.LongTensor, bos_token_id: int, eos_token_id: int, max_length: int = None):
         """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
         one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
         are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
         seq_len)`:
-        - [P, -1, -1, -1, -1, P, P, P]
-        - [P, P, -1, -1, -1, -1, P, P]
-        - [P, P, P, -1, -1, -1, -1, P]
-        - [P, P, P, P, -1, -1, -1, -1]
-        where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
+        - [B, -1, -1, -1, -1, E, E, E]
+        - [B, B, -1, -1, -1, -1, E, E]
+        - [B, B, B, -1, -1, -1, -1, E]
+        - [B, B, B, B, -1, -1, -1, -1]
+        where B is the BOS token id, E is the EOS token id and -1 indicates that the token is valid for prediction. If we include
         a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
         mask is set to the value in the prompt:
-        - [P, a, b, -1, -1, P, P, P]
-        - [P, P, c, d, -1, -1, P, P]
-        - [P, P, P, e, f, -1, -1, P]
-        - [P, P, P, P, g, h, -1, -1]
+        - [B, a, b, -1, -1, E, E, E]
+        - [B, B, c, d, -1, -1, E, E]
+        - [B, B, B, e, f, -1, -1, E]
+        - [B, B, B, B, g, h, -1, -1]
         where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
         tokens in our prediction.
         """
@@ -1147,14 +1150,16 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # construct a pattern mask that indicates the positions of padding tokens for each codebook
         # first fill the upper triangular part (the EOS padding)
-        delay_pattern = torch.triu(
+        eos_delay_pattern = torch.triu(
             torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
         )
         # then fill the lower triangular part (the BOS padding)
-        delay_pattern = delay_pattern + torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
+        bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))

-        mask = ~delay_pattern.to(input_ids.device)
-        input_ids = mask * input_ids_shifted + ~mask * pad_token_id
+        bos_mask = ~(bos_delay_pattern).to(input_ids.device)
+        eos_mask = ~(eos_delay_pattern).to(input_ids.device)
+        mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
+        input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * eos_token_id

         # find the first position to start generating - this is the first place we have the -1 token
         # and will always be in the first codebook (since it has no codebook offset)
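Note: the two triangular patterns can be materialized at the docstring's toy size (4 codebooks, max_length 8). This standalone sketch (not part of the commit) reproduces the docstring's example with B=2048 and E=2049:

    import torch

    num_codebooks, max_length = 4, 8

    # EOS padding: upper-triangular tail of each (delayed) codebook row
    eos_delay_pattern = torch.triu(
        torch.ones((num_codebooks, max_length), dtype=torch.bool),
        diagonal=max_length - num_codebooks + 1,
    )
    # BOS padding: lower-triangular head of each codebook row
    bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))

    render = torch.full((num_codebooks, max_length), -1)
    render[bos_delay_pattern] = 2048  # B
    render[eos_delay_pattern] = 2049  # E
    print(render)  # first row: [2048, -1, -1, -1, -1, 2049, 2049, 2049]

Every row forces the same number of special positions (row i carries i+1 BOS and num_codebooks-1-i EOS), which is what lets the generation code below strip them and still reshape cleanly per codebook.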
@@ -1334,7 +1339,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # Build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
         input_ids, delay_pattern_mask = self.build_delay_pattern_mask(
             input_ids,
-            pad_token_id=generation_config.decoder_start_token_id,
+            bos_token_id=generation_config.decoder_start_token_id,
+            eos_token_id=generation_config.eos_token_id,
             max_length=generation_config.max_length,
         )
@@ -1436,9 +1442,9 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])

-        # revert the pattern delay mask by filtering the pad token id
-        output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
-            batch_size, self.num_codebooks, -1
+        # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
+        output_ids = output_ids[
+            (model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)
+            & (model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)
+        ].reshape(
+            batch_size, self.decoder.num_codebooks, -1
         )

         if generation_config.return_dict_in_generate:
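Note: filtering on the delay pattern mask (rather than on the generated values) is what makes the reshape safe, since every codebook row contains the same number of forced BOS/EOS positions. A toy illustration with one sample and two codebooks (ids as in this commit; values made up):

    import torch

    bos_id, eos_id, batch_size, num_codebooks = 2048, 2049, 1, 2
    delay_pattern_mask = torch.tensor([[bos_id, -1, -1, eos_id],
                                       [bos_id, bos_id, -1, -1]])
    output_ids = torch.tensor([[bos_id, 7, 11, eos_id],
                               [bos_id, bos_id, 5, 13]])

    # each row drops exactly two forced positions, so the reshape lines up
    kept = output_ids[(delay_pattern_mask != bos_id) & (delay_pattern_mask != eos_id)]
    print(kept.reshape(batch_size, num_codebooks, -1))  # tensor([[[ 7, 11], [ 5, 13]]])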
@@ -1919,7 +1925,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if prompt_hidden_states is None:
             if prompt_input_ids is not None:
                 prompt_hidden_states = self.embed_prompts(prompt_input_ids)

-        # TODO: do we do something with prompt_attention_mask? e.g. multiply it to prompt_hidden_states?
+        # TODO: verify prompt_attention_mask
         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
             decoder_input_ids = shift_tokens_right(
@@ -1999,7 +2005,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if decoder_delay_pattern_mask is None:
             decoder_input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
                 decoder_input_ids,
-                self.generation_config.pad_token_id,
+                bos_token_id=self.generation_config.decoder_start_token_id,
+                eos_token_id=self.generation_config.eos_token_id,
                 max_length=self.generation_config.max_length,
             )
@@ -2428,7 +2435,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # build the delay pattern mask for offsetting each codebook prediction by 1 (this behaviour is specific to Stable Speech)
         input_ids, decoder_delay_pattern_mask = self.decoder.build_delay_pattern_mask(
             input_ids,
-            pad_token_id=generation_config.decoder_start_token_id,
+            bos_token_id=generation_config.decoder_start_token_id,
+            eos_token_id=generation_config.eos_token_id,
             max_length=generation_config.max_length,
         )

         # stash the delay mask so that we don't have to recompute in each forward pass
@@ -2531,8 +2539,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

-        # revert the pattern delay mask by filtering the pad token id
-        output_ids = output_ids[output_ids != generation_config.pad_token_id].reshape(
+        # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
+        output_ids = output_ids[
+            (model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id)
+            & (model_kwargs["decoder_delay_pattern_mask"] != generation_config.eos_token_id)
+        ].reshape(
             batch_size, self.decoder.num_codebooks, -1
         )
@@ -2543,10 +2551,23 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if audio_scales is None:
             audio_scales = [None] * batch_size

-        output_values = self.audio_encoder.decode(
-            output_ids,
-            audio_scales=audio_scales,
-        ).audio_values
+        decode_in_batch = ((output_ids == generation_config.bos_token_id).sum() + (output_ids == generation_config.eos_token_id).sum()) == 0
+        if decode_in_batch.item():
+            output_values = self.audio_encoder.decode(
+                output_ids,
+                audio_scales=audio_scales,
+            ).audio_values
+        else:
+            output_values = []
+            for sample_id in range(batch_size):
+                sample = output_ids[:, sample_id]
+                sample_mask = ((sample == generation_config.bos_token_id) | (sample == generation_config.eos_token_id)).sum(dim=(0, 1)) == 0
+                sample = sample[:, :, sample_mask]
+                sample = self.audio_encoder.decode(sample[None, ...], [audio_scales[sample_id]]).audio_values
+                output_values.append(sample.transpose(0, 2))
+            # TODO: we should keep track of output length as well. Not really straightforward tbh
+            output_values = torch.nn.utils.rnn.pad_sequence(output_values, batch_first=True, padding_value=0).transpose(1, 2).squeeze(-1)

         if generation_config.return_dict_in_generate:
             outputs.sequences = output_values
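Note: the per-sample fallback drops every timestep where any codebook still holds a BOS or EOS id before handing the sample to the codec. A sketch of that column masking with toy shapes (ids follow the commit's defaults; the codes are fake):

    import torch

    bos_id, eos_id = 2048, 2049
    sample = torch.randint(0, 2048, (1, 4, 6))  # (1, codebooks, seq_len)
    sample[0, :, 5] = eos_id                    # pretend the last column is EOS

    # keep only columns free of BOS/EOS across all codebooks
    keep = ((sample == bos_id) | (sample == eos_id)).sum(dim=(0, 1)) == 0
    print(sample[:, :, keep].shape)             # torch.Size([1, 4, 5])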