parler-tts (chenpangpang) · Commit d0140745

Authored Feb 27, 2024 by Yoach Lacombe · parent ee12a812

    update code: fix accelerate, fix delay pattern mask, improve generation

Showing 8 changed files with 140 additions and 116 deletions:
  example_configs/librispeech_tts_r.json        +2  -2
  example_configs/librispeech_tts_r_dummy.json  +9  -9
  init_dummy_model.py                           +3  -3
  init_model.py                                 +6  -6
  run_stable_speech_training.py                 +24 -25
  stable_speech/__init__.py                     +1  -1
  stable_speech/configuration_stable_speech.py  +5  -5
  stable_speech/modeling_stable_speech.py       +90 -65
example_configs/librispeech_tts_r.json

@@ -36,8 +36,8 @@
     "preprocessing_num_workers": 1,
-    "pad_token_id": 2050,
-    "decoder_start_token_id": 2048,
+    "pad_token_id": 2048,
+    "decoder_start_token_id": 2049,
     "do_train": true,
     "num_train_epochs": 120,
example_configs/librispeech_tts_r_dummy.json

@@ -24,8 +24,8 @@
     "description_column_name": "text_description",
     "prompt_column_name": "text",
-    "max_train_samples": 12,
-    "max_eval_samples": 12,
+    "max_train_samples": 4,
+    "max_eval_samples": 4,
     "max_duration_in_seconds": 30,
@@ -36,14 +36,14 @@
     "preprocessing_num_workers": 1,
-    "pad_token_id": 2050,
-    "decoder_start_token_id": 2048,
+    "pad_token_id": 2048,
+    "decoder_start_token_id": 2049,
     "do_train": true,
-    "num_train_epochs": 20,
+    "num_train_epochs": 120,
     "gradient_accumulation_steps": 1,
     "gradient_checkpointing": false,
-    "per_device_train_batch_size": 3,
+    "per_device_train_batch_size": 2,
     "learning_rate": 1e-3,
     "adam_beta1": 0.9,
     "adam_beta2": 0.999,
@@ -60,10 +60,10 @@
     "predict_with_generate": true,
     "include_inputs_for_metrics": true,
     "evaluation_strategy": "steps",
-    "eval_steps": 10,
-    "per_device_eval_batch_size": 3,
+    "eval_steps": 30,
+    "per_device_eval_batch_size": 2,
     "generation_max_length": 400,
-    "do_sample": true,
+    "do_sample": false,
     "logging_steps": 15,
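Both example configs move to the new special-token layout: the pad id drops from 2050 to 2048 (shared with EOS under the new vocabulary, see configuration_stable_speech.py below) and the decoder start token becomes 2049. A minimal sanity check of a config against this convention — an illustrative sketch, not part of the repo, assuming it is run from the repo root:

import json

# Hypothetical check: the new layout uses 2048 for pad/eos and 2049 for the
# decoder start (BOS) token, matching StableSpeechDecoderConfig in this commit.
with open("example_configs/librispeech_tts_r.json") as f:
    cfg = json.load(f)

assert cfg["pad_token_id"] == 2048
assert cfg["decoder_start_token_id"] == 2049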
init_dummy_model.py

@@ -34,9 +34,9 @@ model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
 )
 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2050
-model.generation_config.eos_token_id = 2049
+model.generation_config.decoder_start_token_id = 2049
+model.generation_config.pad_token_id = 2048
+model.generation_config.eos_token_id = 2048

 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
init_model.py

@@ -18,7 +18,7 @@ decoder_config = StableSpeechDecoderConfig(
 decoder = StableSpeechForCausalLM(decoder_config)
-decoder.save_pretrained("/raid/yoach/tmp/decoder/")
+decoder.save_pretrained("/home/yoach/dataspeech/artefacts/decoder/")

 t5 = AutoConfig.from_pretrained("t5-base")
@@ -26,18 +26,18 @@ t5 = AutoConfig.from_pretrained("t5-base")
 model = StableSpeechForConditionalGeneration.from_sub_models_pretrained(
     text_encoder_pretrained_model_name_or_path="t5-base",
     audio_encoder_pretrained_model_name_or_path="facebook/encodec_32khz",
-    decoder_pretrained_model_name_or_path="/raid/yoach/tmp/decoder/",
+    decoder_pretrained_model_name_or_path="/home/yoach/dataspeech/artefacts/decoder/",
     vocab_size=t5.vocab_size
 )

 # set the appropriate bos/pad token ids
-model.generation_config.decoder_start_token_id = 2048
-model.generation_config.pad_token_id = 2050
-model.generation_config.eos_token_id = 2049
+model.generation_config.decoder_start_token_id = 2049
+model.generation_config.pad_token_id = 2048
+model.generation_config.eos_token_id = 2048

 # set other default generation config params
 model.generation_config.max_length = int(30 * model.audio_encoder.config.frame_rate)
 model.generation_config.do_sample = False # True
 model.generation_config.guidance_scale = 1 # 3.0

-model.save_pretrained("/raid/yoach/tmp/small-stable-speech-untrained/")
\ No newline at end of file
+model.save_pretrained("/home/yoach/dataspeech/artefacts/small-stable-speech-untrained/")
\ No newline at end of file
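Both init scripts cap generation at 30 seconds of audio via the codec frame rate. A sketch of the same arithmetic — the 50 Hz figure is an assumption about facebook/encodec_32khz, not stated in the diff; the scripts read it from model.audio_encoder.config.frame_rate instead of hard-coding it:

# max_length is measured in decoder steps per codebook, i.e. codec frames
seconds = 30
frame_rate = 50  # assumed frames per second for the 32 kHz EnCodec checkpoint
max_length = int(seconds * frame_rate)
print(max_length)  # 1500 with the assumed frame rate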
run_stable_speech_training.py

@@ -26,6 +26,8 @@ import shutil
 import warnings
 import math
+import time
+from multiprocess import set_start_method

 import evaluate
 from tqdm import tqdm
@@ -63,7 +65,7 @@ from accelerate import Accelerator
 from accelerate.utils import set_seed

-from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig
+from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask

 if is_wandb_available():
     from wandb import Audio
@@ -516,15 +518,10 @@ class DataCollatorStableSpeechWithPadding:
         # (bsz, seq_len, num_codebooks)
         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)

-        delay_pattern_mask = [torch.tensor(feature["label_delay_pattern_mask"]).transpose(0, 1) for feature in features]
-        # (bsz, seq_len, num_codebooks)
-        delay_pattern_mask = torch.nn.utils.rnn.pad_sequence(delay_pattern_mask, batch_first=True, padding_value=-100)
-
         input_ids = [{"input_ids": feature["input_ids"]} for feature in features]
         input_ids = self.description_tokenizer.pad(input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)

-        batch = {"labels": labels, "label_delay_pattern_mask": delay_pattern_mask, **input_ids}
+        batch = {"labels": labels, **input_ids}

         prompt_input_ids = [{"input_ids": feature["prompt_input_ids"]} for feature in features]
         prompt_input_ids = self.prompt_tokenizer.pad(prompt_input_ids, return_tensors="pt", padding=self.padding, pad_to_multiple_of=self.pad_to_multiple_of)
@@ -1014,23 +1011,30 @@ def main():
             len_ = int(all_ratios[idx] * all_lens[idx])
             labels = labels[:, :, :len_]
-            # TODO: remove, only for test
-            labels = labels[:, :, :(len_) % 10 + 20]
-            # add bos and eos token column
-            labels = torch.cat([bos_labels, labels, eos_labels.to(labels.device).to(labels.dtype)], dim=-1)
-            labels, delay_pattern_mask = model.decoder.build_delay_pattern_mask(labels,
+            labels = labels[:, :, :(len_) % 10 + 20] # TODO: change
+            # add bos
+            labels = torch.cat([bos_labels, labels], dim=-1)
+            labels, delay_pattern_mask = build_delay_pattern_mask(labels,
                 bos_token_id=audio_encoder_bos_token_id,
-                pad_token_id=audio_encoder_pad_token_id,
-                max_length=labels.shape[-1] + num_codebooks)
-            labels = model.decoder.apply_delay_pattern_mask(labels, delay_pattern_mask)
+                pad_token_id=audio_encoder_eos_token_id,
+                max_length=labels.shape[-1] + num_codebooks,
+                num_codebooks=num_codebooks)
+            # the first ids of the delay pattern mask are precisely labels, we use the rest of the labels mask
+            # to take care of EOS
+            # we want labels to look like this:
+            #  - [B, a, b, E, E, E, E]
+            #  - [B, B, c, d, E, E, E]
+            #  - [B, B, B, e, f, E, E]
+            #  - [B, B, B, B, g, h, E]
+            labels = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)
             # the first timestamp is associated to a row full of BOS, let's get rid of it
-            sample["labels"] = labels[:, 1:]
-            sample["label_delay_pattern_mask"] = delay_pattern_mask[:, 1:]
-            # we also remove the last timestampts (full of PAD)
+            sample["labels"] = labels[:, 1:].cpu()
             return sample
         # TODO: done multiple times, how to deal with it.
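The rewritten label-building step passes the BOS-prefixed labels through build_delay_pattern_mask with pad_token_id set to the EOS id, then fills the remaining -1 slots with EOS via torch.where. A toy illustration of that last step, with made-up codec tokens (B = 2049 and E = 2048 as in this commit, two codebooks):

import torch

B, E = 2049, 2048  # bos / eos ids used by this commit
# delay pattern mask as returned when the full label sequence was the prompt:
# the real tokens already sit in the mask, -1 remains only past the prompt
mask = torch.tensor([[B, 10, 11, -1, -1],
                     [B,  B, 20, 21, -1]])
labels = torch.where(mask == -1, E, mask)
# -> [[B, 10, 11, E, E],
#     [B,  B, 20, 21, E]]  (the [B, a, b, E, E] shape described above)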
@@ -1047,12 +1051,6 @@ def main():
             del generate_labels

-        if data_args.add_audio_samples_to_wandb and "wandb" in training_args.report_to:
-            if is_wandb_available():
-                from transformers.integrations import WandbCallback
-            else:
-                raise ValueError("`args.add_audio_samples_to_wandb=True` but wandb is not installed. See https://docs.wandb.ai/quickstart to install.")
-
     # for large datasets it is advised to run the preprocessing on a
     # single machine first with ``args.preprocessing_only`` since there will mostly likely
@@ -1467,4 +1465,5 @@ def main():

 if __name__ == "__main__":
+    set_start_method("spawn")
     main()
\ No newline at end of file
stable_speech/__init__.py

 from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
-from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration
\ No newline at end of file
+from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration, apply_delay_pattern_mask, build_delay_pattern_mask
\ No newline at end of file
stable_speech/configuration_stable_speech.py

@@ -38,7 +38,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     Args:
-        vocab_size (`int`, *optional*, defaults to 2050):
+        vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the StableSpeechDecoder model. Defines the number of different tokens that can be
             represented by the `inputs_ids` passed when calling [`StableSpeechDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
@@ -81,7 +81,7 @@ class StableSpeechDecoderConfig(PretrainedConfig):
     def __init__(
         self,
-        vocab_size=2050,  # vocab size = 2048 (encodec vocab size) + 2 (bos, eos)
+        vocab_size=2049,  # vocab size = 2048 (encodec vocab size) + 1 (eos)
         max_position_embeddings=2048,
         num_hidden_layers=24,
         ffn_dim=4096,
@@ -96,9 +96,9 @@ class StableSpeechDecoderConfig(PretrainedConfig):
         initializer_factor=0.02,
         scale_embedding=False,
         num_codebooks=4,
-        pad_token_id=2050,
-        bos_token_id=2048,
-        eos_token_id=2049,
+        pad_token_id=2048,
+        bos_token_id=2049,
+        eos_token_id=2048,
         tie_word_embeddings=False,
         **kwargs,
     ):
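The decoder vocabulary is reworked accordingly: 2048 EnCodec codes plus a single EOS, with pad reusing the EOS id and BOS sitting one past the vocabulary. A sketch of the layout — the note about the extra embedding row is an assumption carried over from MusicGen-style decoders, not something this diff states:

ENCODEC_CODES = 2048                # codec codes occupy ids 0..2047
vocab_size = ENCODEC_CODES + 1      # + 1 (eos) -> 2049
eos_token_id = pad_token_id = 2048  # last in-vocab id, shared by eos and pad
bos_token_id = 2049                 # one past the vocab; assumed embeddable
                                    # because MusicGen-style decoders allocate
                                    # vocab_size + 1 embedding rows
assert eos_token_id == vocab_size - 1
assert bos_token_id == vocab_size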
stable_speech/modeling_stable_speech.py

@@ -60,6 +60,77 @@ MUSICGEN_PRETRAINED_MODEL_ARCHIVE_LIST = [
     # See all StableSpeech models at https://huggingface.co/models?filter=stable_speech
 ]

+
+def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
+    """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
+    the mask is set to -1, and otherwise setting to the value detailed in the mask."""
+    seq_len = input_ids.shape[-1]
+    decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
+    input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
+    return input_ids
+
+
+def build_delay_pattern_mask(input_ids: torch.LongTensor, bos_token_id: int, pad_token_id: int, max_length: int, num_codebooks: int):
+    """Build a delayed pattern mask to the input_ids. Each codebook is offset by the previous codebook by
+    one, giving a delayed pattern mask at the start of sequence and end of sequence. Take the example where there
+    are 4 codebooks and a max sequence length of 8, we have the delayed pattern mask of shape `(codebooks, seq_len)`:
+    - [B, -1, -1, -1, -1, P, P, P]
+    - [B, B, -1, -1, -1, -1, P, P]
+    - [B, B, B, -1, -1, -1, -1, P]
+    - [B, B, B, B, -1, -1, -1, -1]
+    where P is the special padding token id and -1 indicates that the token is valid for prediction. If we include
+    a prompt (decoder input ids), the -1 positions indicate where new tokens should be predicted. Otherwise, the
+    mask is set to the value in the prompt:
+    - [B, a, b, -1, -1, P, P, P]
+    - [B, B, c, d, -1, -1, P, P]
+    - [B, B, B, e, f, -1, -1, P]
+    - [B, B, B, B, g, h, -1, -1]
+    where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
+    tokens in our prediction.
+    """
+    # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+    input_ids = input_ids.reshape(-1, num_codebooks, input_ids.shape[-1])
+    bsz, num_codebooks, seq_len = input_ids.shape
+
+    input_ids_shifted = (
+        torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
+    )
+
+    # we only apply the mask if we have a large enough seq len - otherwise we return as is
+    if max_length < 2 * num_codebooks - 1:
+        return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
+
+    # fill the shifted ids with the prompt entries, offset by the codebook idx
+    for codebook in range(num_codebooks):
+        # mono channel - loop over the codebooks one-by-one
+        input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
+
+    # construct a pattern mask that indicates the positions of padding tokens for each codebook
+    # first fill the upper triangular part (the EOS padding)
+    eos_delay_pattern = torch.triu(
+        torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
+    )
+    # then fill the lower triangular part (the BOS padding)
+    bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
+
+    bos_mask = ~(bos_delay_pattern).to(input_ids.device)
+    eos_mask = ~(eos_delay_pattern).to(input_ids.device)
+    mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
+    input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
+
+    # find the first position to start generating - this is the first place we have the -1 token
+    # and will always be in the first codebook (since it has no codebook offset)
+    first_codebook_ids = input_ids[:, 0, :]
+    start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
+    if len(start_ids) > 0:
+        first_start_id = min(start_ids)
+    else:
+        # we have no tokens that need to be filled - return entire matrix of input ids
+        first_start_id = seq_len
+
+    # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
+    pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
+    input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
+    return input_ids, pattern_mask
+
+
 @dataclass
 class StableSpeechUnconditionalInput(ModelOutput):
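A quick illustration of the pattern documented in the docstring, using the module-level function as added above (toy values: B = 2049, P = 2048, two codebooks, a prompt containing only the BOS frame):

import torch
from stable_speech import build_delay_pattern_mask

input_ids = torch.tensor([[2049], [2049]])  # (bsz * num_codebooks, seq_len)
ids, pattern = build_delay_pattern_mask(
    input_ids, bos_token_id=2049, pad_token_id=2048, max_length=6, num_codebooks=2
)
print(pattern)
# tensor([[2049,   -1,   -1,   -1,   -1, 2048],
#         [2049, 2049,   -1,   -1,   -1,   -1]])
# i.e. [B, -1, -1, -1, -1, P] over [B, B, -1, -1, -1, -1], as in the docstring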
@@ -982,7 +1053,6 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
-        label_delay_pattern_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1031,16 +1101,9 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
             # (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
             labels = labels.masked_fill(labels == self.config.bos_token_id, -100)
             labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
-            # loss = loss_fct(logits.transpose(1,3), labels)
-            # -100 labels are ignored
-            # TODO: probably no need for label_delay_pattern_mask
-            # mask = label_delay_pattern_mask[:, :labels.shape[1]]
-            # mask = (labels != self.generation_config.bos_token_id)&(labels != -100)
-            mask = (labels != -100)
+            # we use every codebooks token AND one single EOS at the end of each codebooks
+            mask = (input_ids.transpose(1, 2) != self.config.eos_token_id) & ((labels != -100))

             # per codebook cross-entropy
             for codebook in range(self.config.num_codebooks):
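The hunk only shows the new mask and the loop header, so the following is an illustrative sketch of what a per-codebook cross-entropy over these shapes looks like — the loop body and the final reduction are assumptions, not taken from the diff:

import torch
import torch.nn.functional as F

bsz, vocab, seq_len, num_codebooks = 2, 2049, 7, 4
logits = torch.randn(bsz, vocab, seq_len, num_codebooks)
labels = torch.randint(0, 2048, (bsz, seq_len, num_codebooks))
mask = labels != -100  # simplified stand-in for the eos-aware mask above

loss = 0.0
for codebook in range(num_codebooks):
    cb_logits = logits[..., codebook]  # (bsz, vocab, seq_len)
    cb_labels = labels[..., codebook].masked_fill(~mask[..., codebook], -100)
    loss = loss + F.cross_entropy(cb_logits, cb_labels, ignore_index=-100)
loss = loss / num_codebooks  # assumed reduction; not shown in the hunk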
@@ -1152,60 +1215,14 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         where a-h indicate the input prompt (decoder input ids) that are offset by 1. Now, we only override the -1
         tokens in our prediction.
         """
-        # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
-        input_ids = input_ids.reshape(-1, self.num_codebooks, input_ids.shape[-1])
-        bsz, num_codebooks, seq_len = input_ids.shape
-        max_length = max_length if max_length is not None else self.generation_config.max_length
-        input_ids_shifted = (
-            torch.ones((bsz, num_codebooks, max_length), dtype=torch.long, device=input_ids.device) * -1
-        )
-        # we only apply the mask if we have a large enough seq len - otherwise we return as is
-        if max_length < 2 * num_codebooks - 1:
-            return input_ids.reshape(bsz * num_codebooks, -1), input_ids_shifted.reshape(bsz * num_codebooks, -1)
-        # fill the shifted ids with the prompt entries, offset by the codebook idx
-        for codebook in range(num_codebooks):
-            # mono channel - loop over the codebooks one-by-one
-            input_ids_shifted[:, codebook, codebook : seq_len + codebook] = input_ids[:, codebook]
-        # construct a pattern mask that indicates the positions of padding tokens for each codebook
-        # first fill the upper triangular part (the EOS padding)
-        eos_delay_pattern = torch.triu(
-            torch.ones((num_codebooks, max_length), dtype=torch.bool), diagonal=max_length - num_codebooks + 1
-        )
-        # then fill the lower triangular part (the BOS padding)
-        bos_delay_pattern = torch.tril(torch.ones((num_codebooks, max_length), dtype=torch.bool))
-        bos_mask = ~(bos_delay_pattern).to(input_ids.device)
-        eos_mask = ~(eos_delay_pattern).to(input_ids.device)
-        mask = ~(bos_delay_pattern + eos_delay_pattern).to(input_ids.device)
-        input_ids = mask * input_ids_shifted + ~bos_mask * bos_token_id + ~eos_mask * pad_token_id
-        # find the first position to start generating - this is the first place we have the -1 token
-        # and will always be in the first codebook (since it has no codebook offset)
-        first_codebook_ids = input_ids[:, 0, :]
-        start_ids = (first_codebook_ids == -1).nonzero()[:, 1]
-        if len(start_ids) > 0:
-            first_start_id = min(start_ids)
-        else:
-            # we have no tokens that need to be filled - return entire matrix of input ids
-            first_start_id = seq_len
-        # (bsz * num_codebooks, seq_len) -> (bsz, num_codebooks, seq_len)
-        pattern_mask = input_ids.reshape(bsz * num_codebooks, -1)
-        input_ids = input_ids[..., :first_start_id].reshape(bsz * num_codebooks, -1)
-        return input_ids, pattern_mask
+        return build_delay_pattern_mask(input_ids, bos_token_id, pad_token_id, max_length, self.num_codebooks)

     @staticmethod
     def apply_delay_pattern_mask(input_ids, decoder_pad_token_mask):
         """Apply a delay pattern mask to the decoder input ids, only preserving predictions where
         the mask is set to -1, and otherwise setting to the value detailed in the mask."""
-        seq_len = input_ids.shape[-1]
-        decoder_pad_token_mask = decoder_pad_token_mask[..., :seq_len]
-        input_ids = torch.where(decoder_pad_token_mask == -1, input_ids, decoder_pad_token_mask)
-        return input_ids
+        return apply_delay_pattern_mask(input_ids, decoder_pad_token_mask)
     @torch.no_grad()
     def generate(
@@ -1219,7 +1236,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         **kwargs,
     ):
         """
+        # TODO: adapt this generate with latest change
         Generates sequences of token ids for models with a language modeling head.

         <Tip warning={true}>
@@ -1868,7 +1885,6 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
         labels: Optional[torch.LongTensor] = None,
-        label_delay_pattern_mask: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
@@ -1991,7 +2007,6 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             past_key_values=past_key_values,
             return_dict=return_dict,
             labels=labels,
-            label_delay_pattern_mask=label_delay_pattern_mask,
             **kwargs_decoder,
         )
@@ -2074,6 +2089,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             "head_mask": head_mask,
             "decoder_head_mask": decoder_head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
+            "prompt_hidden_states": prompt_hidden_states,
+            "prompt_attention_mask": prompt_attention_mask,
             "use_cache": use_cache,
         }
@@ -2564,9 +2581,17 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         # apply the pattern mask to the final ids
         output_ids = self.decoder.apply_delay_pattern_mask(output_ids, model_kwargs["decoder_delay_pattern_mask"])

         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-        # TODO: probably won't work...
-        output_ids = output_ids[(model_kwargs["decoder_delay_pattern_mask"] != generation_config.bos_token_id) & (model_kwargs["decoder_delay_pattern_mask"] != generation_config.pad_token_id)].reshape(
+        _, mask = self.decoder.build_delay_pattern_mask(
+            input_ids,
+            bos_token_id=generation_config.bos_token_id,
+            pad_token_id=generation_config.pad_token_id,
+            max_length=output_ids.shape[1],
+        )
+
+        mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
+        output_ids = output_ids[mask].reshape(
             batch_size, self.decoder.num_codebooks, -1
         )
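A toy check of the revert step (made-up ids; B = 2049, P = 2048): rebuilding the pattern mask over the generated length and dropping every BOS/PAD slot restores the undelayed (bsz, num_codebooks, seq_len) layout, since the delay pattern leaves the same number of real tokens per codebook:

import torch

B, P = 2049, 2048
output_ids = torch.tensor([[[B, 10, 11, P],
                            [B,  B, 20, 21]]])  # (1, 2, 4), delayed
pattern = torch.tensor([[[B, -1, -1, P],
                         [B,  B, -1, -1]]])
mask = (pattern != B) & (pattern != P)
clean = output_ids[mask].reshape(1, 2, -1)
print(clean)  # tensor([[[10, 11], [20, 21]]])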
@@ -2577,8 +2602,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if audio_scales is None:
             audio_scales = [None] * batch_size

-        decode_in_batch = ((output_ids == generation_config.bos_token_id).sum() + (output_ids == generation_config.eos_token_id).sum()) == 0
-        if decode_in_batch.item():
+        decode_sequentially = generation_config.bos_token_id in output_ids or generation_config.pad_token_id in output_ids or generation_config.eos_token_id in output_ids
+        if not decode_sequentially:
             output_values = self.audio_encoder.decode(output_ids, audio_scales=audio_scales,
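The batch-versus-sequential dispatch now uses plain membership tests: if any special id survives in output_ids, each sequence must be decoded on its own so its padding can be stripped first. A sketch of the test with a toy tensor (ids follow this commit's convention; `in` on a tensor checks element membership):

import torch

bos, pad, eos = 2049, 2048, 2048
output_ids = torch.tensor([[[17, 23, 2048]]])  # (bsz, num_codebooks, seq_len)
decode_sequentially = bos in output_ids or pad in output_ids or eos in output_ids
print(decode_sequentially)  # True: the eos id 2048 is still present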