chenpangpang / parler-tts / Commits

Commit d0140745, authored Feb 27, 2024 by Yoach Lacombe
Parent: ee12a812

    update code: fix accelerate, fix delay pattern mask, improve generation

Showing 8 changed files with 140 additions and 116 deletions (+140 −116).
Files changed:
  example_configs/librispeech_tts_r.json        +2  −2
  example_configs/librispeech_tts_r_dummy.json  +9  −9
  init_dummy_model.py                           +3  −3
  init_model.py                                 +6  −6
  run_stable_speech_training.py                 +24 −25
  stable_speech/__init__.py                     +1  −1
  stable_speech/configuration_stable_speech.py  +5  −5
  stable_speech/modeling_stable_speech.py       +90 −65
example_configs/librispeech_tts_r.json
@@ -36,8 +36,8 @@
     "preprocessing_num_workers": 1,
-    "pad_token_id": 2050,
-    "decoder_start_token_id": 2048,
+    "pad_token_id": 2048,
+    "decoder_start_token_id": 2049,
     "do_train": true,
     "num_train_epochs": 120,
example_configs/librispeech_tts_r_dummy.json
@@ -24,8 +24,8 @@
     "description_column_name": "text_description",
     "prompt_column_name": "text",
-    "max_train_samples": 12,
-    "max_eval_samples": 12,
+    "max_train_samples": 4,
+    "max_eval_samples": 4,
     "max_duration_in_seconds": 30,
@@ -36,14 +36,14 @@
     "preprocessing_num_workers": 1,
-    "pad_token_id": 2050,
-    "decoder_start_token_id": 2048,
+    "pad_token_id": 2048,
+    "decoder_start_token_id": 2049,
     "do_train": true,
-    "num_train_epochs": 20,
+    "num_train_epochs": 120,
     "gradient_accumulation_steps": 1,
     "gradient_checkpointing": false,
-    "per_device_train_batch_size": 3,
+    "per_device_train_batch_size": 2,
     "learning_rate": 1e-3,
     "adam_beta1": 0.9,
     "adam_beta2": 0.999,
@@ -60,10 +60,10 @@
     "predict_with_generate": true,
     "include_inputs_for_metrics": true,
     "evaluation_strategy": "steps",
-    "eval_steps": 10,
-    "per_device_eval_batch_size": 3,
+    "eval_steps": 30,
+    "per_device_eval_batch_size": 2,
     "generation_max_length": 400,
-    "do_sample": true,
+    "do_sample": false,
     "logging_steps": 15,
init_dummy_model.py
@@ -34,9 +34,9 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
 )

 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2050
-model.generation_config.eos_token_id = 2049
+model.generation_config.decoder_start_token_id = 2049
+model.generation_config.pad_token_id = 2048
+model.generation_config.eos_token_id = 2048

 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
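For reference, a quick sanity check (editorial sketch, not part of the commit) of what the unchanged `max_length` line evaluates to, assuming the facebook/encodec_32khz audio encoder used in init_model.py, whose frame rate is 50 frames per second:

    frame_rate = 50                    # facebook/encodec_32khz: 32_000 Hz / hop length 640
    max_length = int(30 * frame_rate)  # 30 seconds of audio -> 1500 decoder steps
    print(max_length)                  # 1500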
init_model.py
@@ -18,7 +18,7 @@ decoder_config = StableSpeechDecoderConfig(
 decoder = StableSpeechForCausalLM(decoder_config)
-decoder.save_pretrained("/raid/yoach/tmp/decoder/")
+decoder.save_pretrained("/home/yoach/dataspeech/artefacts/decoder/")

 t5 = AutoConfig.from_pretrained("t5-base")
@@ -26,18 +26,18 @@ t5 = AutoConfig.from_pretrained("t5-base")
 model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path="t5-base",
     audio_encoder_pretrained_model_name_or_path="facebook/encodec_32khz",
-    decoder_pretrained_model_name_or_path="/raid/yoach/tmp/decoder/",
+    decoder_pretrained_model_name_or_path="/home/yoach/dataspeech/artefacts/decoder/",
     vocab_size=t5.vocab_size
 )

 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2050
-model.generation_config.eos_token_id = 2049
+model.generation_config.decoder_start_token_id = 2049
+model.generation_config.pad_token_id = 2048
+model.generation_config.eos_token_id = 2048

 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
 model.generation_config.do_sample = False  # True
 model.generation_config.guidance_scale = 1  # 3.0

-model.save_pretrained("/raid/yoach/tmp/small-stable-speech-untrained/")
+model.save_pretrained("/home/yoach/dataspeech/artefacts/small-stable-speech-untrained/")
\ No newline at end of file
run_stable_speech_training.py
@@ -26,6 +26,8 @@ import shutil
 import warnings
 import math
 import time
+from multiprocess import set_start_method

 import evaluate
 from tqdm import tqdm
@@ -63,7 +65,7 @@ from accelerate import Accelerator
 from accelerate.utils import set_seed

-from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig
+from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask

 if is_wandb_available():
     from wandb import Audio
@@ -516,15 +518,10 @@ class DataCollatorStableSpeechWithPadding:
         # (bsz, seq_len, num_codebooks)
         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
-        delay_pattern_mask = [torch.tensor(feature["label_delay_pattern_mask"]).transpose(0, 1) for feature in features]
-        # (bsz, seq_len, num_codebooks)
-        delay_pattern_mask = torch.nn.utils.rnn.pad_sequence(delay_pattern_mask, batch_first=True, padding_value=-100)

         input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
         input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)

-        batch = {"labels": labels, "label_delay_pattern_mask": delay_pattern_mask, **input_ids}
+        batch = {"labels": labels, **input_ids}

         prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
         prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
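For reference, a minimal self-contained sketch (editorial, not part of the commit) of how `pad_sequence` batches the variable-length label tensors in the collator above, with -100 marking padded target positions:

    import torch

    # two samples of shape (seq_len, num_codebooks); seq_len differs per sample
    labels = [torch.zeros(3, 4, dtype=torch.long), torch.ones(5, 4, dtype=torch.long)]
    batched = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
    print(batched.shape)  # torch.Size([2, 5, 4]); the shorter sample is padded with -100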
@@ -1014,23 +1011,30 @@ def main():
             len_ = int(all_ratios[idx] * all_lens[idx])
             labels = labels[:, :, :len_]

-            # TODO: remove, only for test
+            # TODO: change
             labels = labels[:, :, :(len_) % 10 + 20]

-            # add bos and eos token column
-            labels = torch.cat([bos_labels, labels, eos_labels.to(labels.device).to(labels.dtype)], dim=-1)
-            labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(
+            # add bos
+            labels = torch.cat([bos_labels, labels], dim=-1)
+            labels, delay_pattern_mask = build_delay_pattern_mask(
                 labels,
                 bos_token_id=audio_encoder_bos_token_id,
-                pad_token_id=audio_encoder_pad_token_id,
-                max_length=labels.shape[-1] + num_codebooks
+                pad_token_id=audio_encoder_eos_token_id,
+                max_length=labels.shape[-1] + num_codebooks,
+                num_codebooks=num_codebooks
             )
-            labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)

+            # the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
+            # to take care of EOS
+            # we want labels to look like this:
+            #  - [B, a, b, E, E, E, E]
+            #  - [B, B, c, d, E, E, E]
+            #  - [B, B, B, e, f, E, E]
+            #  - [B, B, B, B, g, h, E]
+            labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)

             # the first timestamp is associated to a row full of BOS, let's get rid of it
-            sample["labels"] = labels[:, 1:]
-            sample["label_delay_pattern_mask"] = delay_pattern_mask[:, 1:]
+            # we also remove the last timestampts (full of PAD)
+            sample["labels"] = labels[:, 1:].cpu()
             return sample

         # TODO: done multiple times, how to deal with it.
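To see what the new label pipeline above produces, here is a toy run (editorial sketch, not part of the commit; it assumes the stable_speech package from this repo is importable and uses the new ids eos=2048, bos=2049):

    import torch
    from stable_speech import build_delay_pattern_mask

    num_codebooks, bos, eos = 2, 2049, 2048
    codes = torch.tensor([[10, 11, 12], [20, 21, 22]])  # (num_codebooks, seq_len)
    labels = torch.cat([torch.full((num_codebooks, 1), bos), codes], dim=-1)

    _, delay_pattern_mask = build_delay_pattern_mask(
        labels, bos_token_id=bos, pad_token_id=eos,
        max_length=labels.shape[-1] + num_codebooks, num_codebooks=num_codebooks,
    )
    # every slot the mask leaves free (-1) becomes EOS; the rest keep the mask value
    labels = torch.where(delay_pattern_mask == -1, eos, delay_pattern_mask)
    print(labels[:, 1:])  # drop the all-BOS first step, as the script does
    # tensor([[  10,   11,   12, 2048, 2048],
    #         [2049,   20,   21,   22, 2048]])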
@@ -1047,12 +1051,6 @@ def main():
         del generate_labels

-        if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to:
-            if is_wandb_available():
-                from transformers.integrations import WandbCallback
-            else:
-                raise ValueError("`args.add_audio_samples_to_wandb=True` but wandb is not installed. See https://docs.wandb.ai/quickstart to install.")

     # for large datasets it is advised to run the preprocessing on a
     # single machine first with ``args.preprocessing_only`` since there will mostly likely
@@ -1467,4 +1465,5 @@ def main():
 if __name__ == "__main__":
+    set_start_method("spawn")
     main()
\ No newline at end of file
stable_speech/__init__.py
 from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
-from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration
+from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
\ No newline at end of file
stable_speech/configuration_stable_speech.py
@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     Args:
-        vocab_size (`int`, *optional*, defaults to 2050):
+        vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
             represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     def __init__(
         self,
-        vocab_size=2050,  # vocab size = 2048 (encodec vocab size) + 2 (bos, eos)
+        vocab_size=2049,  # vocab size = 2048 (encodec vocab size) + 1 (eos)
         max_position_embeddings=2048,
         num_hidden_layers=24,
         ffn_dim=4096,
@@ -96,9 +96,9 @@ class StableSpeechDecoderConfig(PretrainedConfig):
         initializer_factor=0.02,
         scale_embedding=False,
         num_codebooks=4,
-        pad_token_id=2050,
-        bos_token_id=2048,
-        eos_token_id=2049,
+        pad_token_id=2048,
+        bos_token_id=2049,
+        eos_token_id=2048,
         tie_word_embeddings=False,
         **kwargs,
     ):
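Taken together with the config and script changes above, the commit moves from a 2050-token scheme (bos=2048, eos=2049, pad=2050) to a 2049-token scheme. A summary sketch of the new layout (editorial; values copied from the diffs above):

    NEW_TOKEN_LAYOUT = {
        "vocab_size": 2049,              # 2048 encodec codes + 1 (eos)
        "eos_token_id": 2048,
        "pad_token_id": 2048,            # shares its id with eos
        "bos_token_id": 2049,            # also used as decoder_start_token_id
        "decoder_start_token_id": 2049,
    }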
stable_speech/modeling_stable_speech.py
@@ -60,6 +60,77 @@ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all StableSpeech models at https://huggingface.co/models?filter=stable_speech
 ]

+def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
+    """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
+    the mask is set to -1, and otherwise setting to the value detailed in the mask."""
+    seq_len = input_ids.shape[-1]
+    decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
+    input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
+    return input_ids
+
+def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int):
+    """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
+    one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
+    are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks,
+    seq_len)`:
+    - [B, -1, -1, -1, -1, P, P, P]
+    - [B, B, -1, -1, -1, -1, P, P]
+    - [B, B, B, -1, -1, -1, -1, P]
+    - [B, B, B, B, -1, -1, -1, -1]
+    where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
+    a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
+    mask is set to the value in the prompt:
+    - [B, a, b, -1, -1, P, P, P]
+    - [B, B, c, d, -1, -1, P, P]
+    - [B, B, B, e, f, -1, -1, P]
+    - [B, B, B, B, g, h, -1, -1]
+    where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
+    tokens in our prediction.
+    """
+    # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+    input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1])
+    bsz, num_codebooks, seq_len = input_ids.shape
+
+    input_ids_shifted = (
+        torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
+    )
+
+    # we only apply the mask if we have a large enough seq len - otherwise we return as is
+    if max_length < 2 * num_codebooks - 1:
+        return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
+
+    # fill the shifted ids with the prompt entries, offset by the codebook idx
+    for codebook in range(num_codebooks):
+        # mono channel - loop over the codebooks one-by-one
+        input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
+
+    # construct a pattern mask that indicates the positions of padding tokens for each codebook
+    # first fill the upper triangular part (the EOS padding)
+    eos_delay_pattern = torch.triu(
+        torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
+    )
+    # then fill the lower triangular part (the BOS padding)
+    bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
+
+    bos_mask = ~(bos_delay_pattern).to(input_ids.device)
+    eos_mask = ~(eos_delay_pattern).to(input_ids.device)
+    mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
+    input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
+
+    # find the first position to start generating - this is the first place we have the -1 token
+    # and will always be in the first codebook (since it has no codebook offset)
+    first_codebook_ids = input_ids[:, 0, :]
+    start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
+    if len(start_ids) > 0:
+        first_start_id = min(start_ids)
+    else:
+        # we have no tokens that need to be filled - return entire matrix of input ids
+        first_start_id = seq_len
+
+    # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+    pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
+    input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
+    return input_ids, pattern_mask

 @dataclass
 class StableSpeechUnconditionalInput(ModelOutput):
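As a worked example of the two helpers just added (editorial sketch, assuming the stable_speech package is importable): starting from a single BOS column per codebook, build_delay_pattern_mask returns the pattern from the docstring, and apply_delay_pattern_mask later overwrites only the -1 slots:

    import torch
    from stable_speech import apply_delay_pattern_mask, build_delay_pattern_mask

    B, P = 2049, 2048  # bos and pad ids from the new config
    prompt = torch.full((4, 1), B, dtype=torch.long)  # (num_codebooks, seq_len=1)

    input_ids, pattern_mask = build_delay_pattern_mask(
        prompt, bos_token_id=B, pad_token_id=P, max_length=8, num_codebooks=4
    )
    print(pattern_mask)
    # [[B, -1, -1, -1, -1,  P,  P,  P],
    #  [B,  B, -1, -1, -1, -1,  P,  P],
    #  [B,  B,  B, -1, -1, -1, -1,  P],
    #  [B,  B,  B,  B, -1, -1, -1, -1]]   (with B=2049, P=2048)

    # generated ids later replace only the -1 slots:
    generated = torch.arange(32, dtype=torch.long).reshape(4, 8)
    print(apply_delay_pattern_mask(generated, pattern_mask))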
@@ -982,7 +1053,6 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
-        label_delay_pattern_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1031,16 +1101,9 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
             # (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
             labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
-            labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
-            # loss = loss_fct(logits.transpose(1,3), labels)
-            # -100 labels are ignored
-            # TODO: probably no need for label_delay_pattern_mask
-            # mask = label_delay_pattern_mask[:, :labels.shape[1]]
-            # mask = (labels != self.generation_config.bos_token_id)&(labels != -100)
-            mask = (labels != -100)
+            # we use every codebooks token AND one single EOS at the end of each codebooks
+            mask = (input_ids.transpose(1, 2) != self.config.eos_token_id) & ((labels != -100))

             # per codebook cross-entropy
             for codebook in range(self.config.num_codebooks):
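The loop body falls outside this hunk, but a masked per-codebook cross-entropy consistent with the mask above can be sketched as follows (editorial; the shapes and the helper name `per_codebook_ce` are assumptions, not code from the repo):

    import torch
    import torch.nn.functional as F

    def per_codebook_ce(logits, labels, mask):
        # assumed shapes: logits (bsz, num_codebooks, seq_len, vocab),
        # labels and mask (bsz, seq_len, num_codebooks)
        loss = torch.zeros((), device=logits.device)
        num_codebooks = logits.shape[1]
        for codebook in range(num_codebooks):
            codebook_logits = logits[:, codebook].reshape(-1, logits.shape[-1])
            codebook_labels = labels[..., codebook].reshape(-1)
            keep = mask[..., codebook].reshape(-1)
            # keep only positions the mask marks as valid targets
            loss = loss + F.cross_entropy(codebook_logits[keep], codebook_labels[keep])
        return loss / num_codebooks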
@@ -1152,60 +1215,14 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
         tokens in our prediction.
         """
-        # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
-        input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
-        bsz, num_codebooks, seq_len = input_ids.shape
         max_length = max_length if max_length is not None else self.generation_config.max_length
-        input_ids_shifted = (
-            torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
-        )
-        # we only apply the mask if we have a large enough seq len - otherwise we return as is
-        if max_length < 2 * num_codebooks - 1:
-            return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
-        # fill the shifted ids with the prompt entries, offset by the codebook idx
-        for codebook in range(num_codebooks):
-            # mono channel - loop over the codebooks one-by-one
-            input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
-        # construct a pattern mask that indicates the positions of padding tokens for each codebook
-        # first fill the upper triangular part (the EOS padding)
-        eos_delay_pattern = torch.triu(
-            torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
-        )
-        # then fill the lower triangular part (the BOS padding)
-        bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
-        bos_mask = ~(bos_delay_pattern).to(input_ids.device)
-        eos_mask = ~(eos_delay_pattern).to(input_ids.device)
-        mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
-        input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
-        # find the first position to start generating - this is the first place we have the -1 token
-        # and will always be in the first codebook (since it has no codebook offset)
-        first_codebook_ids = input_ids[:, 0, :]
-        start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
-        if len(start_ids) > 0:
-            first_start_id = min(start_ids)
-        else:
-            # we have no tokens that need to be filled - return entire matrix of input ids
-            first_start_id = seq_len
-        # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
-        pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
-        input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
-        return input_ids, pattern_mask
+        return build_delay_pattern_mask(input_ids, bos_token_id, pad_token_id, max_length, self.num_codebooks)

     @staticmethod
     def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
         """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
         the mask is set to -1, and otherwise setting to the value detailed in the mask."""
-        seq_len = input_ids.shape[-1]
-        decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
-        input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
-        return input_ids
+        return apply_delay_pattern_mask(input_ids, decoder_pad_token_mask)

     @torch.no_grad()
     def generate(
@@ -1219,7 +1236,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         **kwargs,
     ):
         """
+        # TODO: adapt this generate with latest change
         Generates sequences of token ids for models with a language modeling head.

         <Tip warning={true}>
@@ -1868,7 +1885,6 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
         labels: Optional[torch.LongTensor] = None,
-        label_delay_pattern_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1991,7 +2007,6 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             past_key_values=past_key_values,
             return_dict=return_dict,
             labels=labels,
-            label_delay_pattern_mask=label_delay_pattern_mask,
             **kwargs_decoder,
         )
@@ -2074,6 +2089,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             "head_mask": head_mask,
             "decoder_head_mask": decoder_head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
+            "prompt_hidden_states": prompt_hidden_states,
+            "prompt_attention_mask": prompt_attention_mask,
             "use_cache": use_cache,
         }
@@ -2564,9 +2581,17 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-        # TODO: probably won't work...
-        output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id) & (model_kwargs["decoder_delay_pattern_mask"] != generation_config.pad_token_id)].reshape(
-            batch_size, self.decoder.num_codebooks, -1
-        )
+        _, mask = self.decoder.build_delay_pattern_mask(
+            input_ids,
+            bos_token_id=generation_config.bos_token_id,
+            pad_token_id=generation_config.pad_token_id,
+            max_length=output_ids.shape[1],
+        )
+        mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
+        output_ids = output_ids[mask].reshape(batch_size, self.decoder.num_codebooks, -1)
@@ -2577,8 +2602,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if audio_scales is None:
             audio_scales = [None] * batch_size

-        decode_in_batch = ((output_ids == generation_config.bos_token_id).sum() + (output_ids == generation_config.eos_token_id).sum()) == 0
-        if decode_in_batch.item():
+        decode_sequentially = generation_config.bos_token_id in output_ids or generation_config.pad_token_id in output_ids or generation_config.eos_token_id in output_ids
+        if not decode_sequentially:
             output_values = self.audio_encoder.decode(
                 output_ids,
                 audio_scales=audio_scales,
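A minimal illustration (editorial sketch) of the new check's semantics: Python's `in` on a tensor tests membership element-wise, so any surviving special id in any sample forces the per-sample decoding path instead of one batched call:

    import torch

    bos, eos, pad = 2049, 2048, 2048
    output_ids = torch.tensor([[10, 11, eos], [12, 13, 14]])

    decode_sequentially = bos in output_ids or pad in output_ids or eos in output_ids
    print(decode_sequentially)  # True -> samples are decoded one by one, not as a batch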