chenpangpang / parler-tts · Commits
Commit 813df4d2
Authored Feb 21, 2024 by Yoach Lacombe
Parent: ca2cd16d

    update modeling code with prompt concat
Showing 3 changed files with 123 additions and 16 deletions (+123 −16):

    stable_speech/__init__.py                      +2    −0
    stable_speech/configuration_stable_speech.py   +4    −1
    stable_speech/modeling_stable_speech.py        +117  −15
stable_speech/__init__.py
--------------------------------------------------------------------------------
+from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
+from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration
\ No newline at end of file
stable_speech/configuration_stable_speech.py
--------------------------------------------------------------------------------
...
@@ -137,6 +137,8 @@ class StableSpeechConfig(PretrainedConfig):
         documentation from [`PretrainedConfig`] for more information.

     Args:
+        prompt_embed_dim (`int`, *optional*, defaults to 1024):
+            Dimensionality of the prompt embedding layer.
         kwargs (*optional*):
             Dictionary of keyword arguments. Notably:
...
@@ -187,7 +189,7 @@ class StableSpeechConfig(PretrainedConfig):
     model_type = "stable_speech"
     is_composition = True

-    def __init__(self, **kwargs):
+    def __init__(self, prompt_embed_dim=1024, **kwargs):
         super().__init__(**kwargs)
         if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
             raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
...
@@ -200,6 +202,7 @@ class StableSpeechConfig(PretrainedConfig):
         decoder_config = kwargs.pop("decoder")

+        self.prompt_embed_dim = prompt_embed_dim
         self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
         self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
         self.decoder = StableSpeechDecoderConfig(**decoder_config)
...
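Since `prompt_embed_dim` is now an `__init__` argument stored on the config, it round-trips through serialization like any other field. A minimal sketch of constructing the composite config, assuming the MusicGen-style sub-config handling this class mirrors (the sub-config dicts here are illustrative only):

    # Sketch only: in practice text_encoder/audio_encoder/decoder must describe
    # real sub-models; "t5"/"encodec" mirror the MusicGen-style setup.
    from stable_speech import StableSpeechConfig

    config = StableSpeechConfig(
        prompt_embed_dim=1024,
        text_encoder={"model_type": "t5"},
        audio_encoder={"model_type": "encodec"},
        decoder={},
    )
    assert config.prompt_embed_dim == 1024
    config.save_pretrained("./ckpt")  # prompt_embed_dim is serialized with the rest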
stable_speech/modeling_stable_speech.py
--------------------------------------------------------------------------------
...
@@ -689,6 +689,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
...
@@ -724,6 +726,22 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])

+        # if prompt_hidden_states, fuse to inputs_embeds and update input shape
+        if prompt_hidden_states is not None:
+            inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
+            input_shape = inputs_embeds.size()[:-1]
+
+        # TODO: verify if prompt attention mask is required
+        # As it is, the masked ids from the prompt will still count in the positions embeddings
+        if prompt_attention_mask is not None and attention_mask is not None:
+            attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
+        elif prompt_attention_mask is not None:
+            logger.warning_once(
+                "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
+            )
+            attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)])
+
         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
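The "prompt concat" of the commit message is exactly this: prompt embeddings are prepended to the summed codebook embeddings along the sequence axis, so the decoder attends to them as ordinary positions, and the attention mask grows by the prompt length. A standalone shape sketch with made-up sizes:

    import torch

    # Hypothetical sizes for illustration only.
    bsz, prompt_len, seq_len, hidden = 2, 5, 12, 1024

    prompt_hidden_states = torch.randn(bsz, prompt_len, hidden)
    inputs_embeds = torch.randn(bsz, seq_len, hidden)

    # Same operation as in the hunk above: prepend the prompt along dim=1.
    fused = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
    assert fused.shape == (bsz, prompt_len + seq_len, hidden)

    # The mask is extended the same way; masked prompt positions are ignored by
    # attention but (per the TODO above) still shift the position embeddings.
    prompt_attention_mask = torch.ones(bsz, prompt_len, dtype=torch.long)
    attention_mask = torch.ones(bsz, seq_len, dtype=torch.long)
    fused_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
    assert fused_mask.shape == (bsz, prompt_len + seq_len)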
...
@@ -862,6 +880,8 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
...
@@ -884,6 +904,8 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
             encoder_hidden_states=encoder_hidden_states,
+            prompt_hidden_states=prompt_hidden_states,
+            prompt_attention_mask=prompt_attention_mask,
             head_mask=head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
...
@@ -951,6 +973,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
...
@@ -962,7 +986,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
...
@@ -976,6 +1000,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
             attention_mask=attention_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
+            prompt_hidden_states=prompt_hidden_states,
+            prompt_attention_mask=prompt_attention_mask,
             head_mask=head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
...
@@ -992,7 +1018,17 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         loss = None
         if labels is not None:
-            raise NotImplementedError("Training is not implemented for StableSpeech.")
+            # since encoder hidden states have concatenated to hidden states, take the last hidden states corresponding to labels
+            logits = lm_logits[:, :, -labels.shape[1] :]
+
+            loss_fct = CrossEntropyLoss()
+            loss = torch.zeros([], device=self.device)
+
+            # per codebook cross-entropy
+            # -100 labels are ignored
+            # (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
+            labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
+            loss = loss_fct(logits.transpose(1, 3), labels)

         # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
         lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
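To make the transpose concrete: `CrossEntropyLoss` accepts K-dimensional input as `(N, C, d1, d2)` against a target of shape `(N, d1, d2)`, so swapping dims 1 and 3 moves the vocabulary axis into the class position while keeping one loss term per timestep and codebook, with `-100` positions ignored. A self-contained sketch with made-up sizes:

    import torch
    from torch.nn import CrossEntropyLoss

    # Hypothetical sizes for illustration only.
    bsz, num_codebooks, seq_len, vocab_size = 2, 9, 10, 2048

    lm_logits = torch.randn(bsz, num_codebooks, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (bsz, seq_len, num_codebooks))
    labels[0, -2:] = -100  # padded positions, skipped by the loss

    # (bsz, num_codebooks, seq_len, vocab) -> (bsz, vocab, seq_len, num_codebooks)
    loss = CrossEntropyLoss()(lm_logits.transpose(1, 3), labels)
    print(loss)  # scalar mean over all non-ignored (timestep, codebook) cells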
...
@@ -1016,6 +1052,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        prompt_hidden_states=None,
+        prompt_attention_mask=None,
         head_mask=None,
         cross_attn_head_mask=None,
         past_key_values=None,
...
@@ -1040,15 +1078,30 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
             input_ids = input_ids.repeat((2, 1))
             if attention_mask is not None:
                 attention_mask = attention_mask.repeat((2, 1))
+            if prompt_hidden_states is not None:
+                prompt_hidden_states = torch.concatenate([prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0)
+            if prompt_attention_mask is not None:
+                prompt_attention_mask = torch.concatenate(prompt_attention_mask, torch.zeros_like(prompt_attention_mask), dim=0)

         if past_key_values is not None:
             input_ids = input_ids[:, -1:]
+            # we only want to use prompt signal in the 1st generation step but keeping the attention mask
+            prompt_hidden_states = None

         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "encoder_hidden_states": encoder_hidden_states,
             "encoder_attention_mask": encoder_attention_mask,
+            "prompt_hidden_states": prompt_hidden_states,
+            "prompt_attention_mask": prompt_attention_mask,
             "head_mask": head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
             "past_key_values": past_key_values,
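The `repeat`/`zeros_like` pattern is the classifier-free guidance batching inherited from MusicGen: the batch is doubled, with the second half given null prompt embeddings as the unconditional branch. Note also that, as extracted here, the `prompt_attention_mask` branch appears to pass its two tensors to `torch.concatenate` positionally rather than in a list, which would raise a `TypeError`; the `prompt_hidden_states` branch above it shows the intended form. A minimal sketch of that intended doubling:

    import torch

    # Hypothetical sizes for illustration only.
    bsz, prompt_len, hidden = 2, 5, 1024
    prompt_hidden_states = torch.randn(bsz, prompt_len, hidden)

    # Conditional half on top, unconditional (null-prompt) half below.
    doubled = torch.concatenate(
        [prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0
    )
    assert doubled.shape == (2 * bsz, prompt_len, hidden)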
...
@@ -1483,6 +1536,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             and self.decoder.config.cross_attention_hidden_size is None
         ):
             self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)

+        # prompt embeddings
+        self.embed_prompts = nn.Embedding(config.prompt_embed_dim, self.decoder.config.hidden_size)
+
         if self.text_encoder.get_output_embeddings() is not None:
             raise ValueError(
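One detail worth flagging: `nn.Embedding(num_embeddings, embedding_dim)` takes the vocabulary size first, so `prompt_embed_dim` is used here as the number of distinct prompt token ids, not as a hidden dimensionality as the new config docstring suggests. A quick sketch under the commit's default of 1024 (the decoder hidden size below is hypothetical):

    import torch
    import torch.nn as nn

    # prompt_embed_dim=1024 acts as the prompt vocabulary size; 768 stands in
    # for self.decoder.config.hidden_size.
    embed_prompts = nn.Embedding(1024, 768)

    prompt_input_ids = torch.randint(0, 1024, (2, 5))  # ids must be < prompt_embed_dim
    prompt_hidden_states = embed_prompts(prompt_input_ids)
    assert prompt_hidden_states.shape == (2, 5, 768)   # (bsz, prompt_len, hidden_size)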
...
@@ -1496,8 +1553,19 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
                 "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
             )

-        # tie text encoder, decoder weights if config set accordingly
-        self.tie_weights()
+        # Initialize projection and embedding layers and tie text encoder and decoder weights if set accordingly
+        self.post_init()
+
+    def _init_weights(self, module):
+        std = self.config.initializer_factor
+        if isinstance(module, (nn.Linear, nn.Conv1d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()

     def tie_weights(self):
         # tie text encoder & decoder if needed
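For readers tracking the `tie_weights()` → `post_init()` change: in transformers, `post_init()` runs `init_weights()`, which applies the `_init_weights` override added above to every submodule and then calls `tie_weights()`, so the old behaviour is preserved while the new `enc_to_dec_proj` and `embed_prompts` layers get initialized. A simplified paraphrase of that call chain (not the literal transformers source):

    import torch.nn as nn

    class PreTrainedModelSketch(nn.Module):
        def post_init(self):
            self.init_weights()

        def init_weights(self):
            self.apply(self._init_weights)  # visits every submodule, incl. embed_prompts
            self.tie_weights()              # what the old __init__ called directly

        def _init_weights(self, module):    # overridden by the model, as in this commit
            pass

        def tie_weights(self):
            pass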
...
@@ -1768,6 +1836,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        prompt_input_ids: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
...
@@ -1844,6 +1915,11 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if attention_mask is not None:
             encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]

+        if prompt_hidden_states is None:
+            if prompt_input_ids is not None:
+                prompt_hidden_states = self.embed_prompts(prompt_input_ids)
+        # TODO: do we do something with prompt_attention_mask ? e.g multiply it to prompt_hidden_states?
+
         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
             decoder_input_ids = shift_tokens_right(
...
@@ -1876,29 +1952,23 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             attention_mask=decoder_attention_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=attention_mask,
+            prompt_hidden_states=prompt_hidden_states,
+            prompt_attention_mask=prompt_attention_mask,
             inputs_embeds=decoder_inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
             past_key_values=past_key_values,
             return_dict=return_dict,
+            labels=labels,
             **kwargs_decoder,
         )

-        loss = None
-        if labels is not None:
-            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
-
         if not return_dict:
-            if loss is not None:
-                return (loss,) + decoder_outputs + encoder_outputs
-            else:
-                return decoder_outputs + encoder_outputs
+            return decoder_outputs + (encoder_hidden_states,)

         return Seq2SeqLMOutput(
-            loss=loss,
+            loss=decoder_outputs.loss,
             logits=decoder_outputs.logits,
             past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
...
@@ -1917,6 +1987,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         head_mask=None,
         decoder_attention_mask=None,
         decoder_head_mask=None,
+        prompt_hidden_states=None,
+        prompt_attention_mask=None,
         cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
...
@@ -1940,6 +2012,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             decoder_input_ids = decoder_input_ids.repeat((2, 1))
             if decoder_attention_mask is not None:
                 decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
+            if prompt_hidden_states is not None:
+                # TODO: ? we probably don't want to keep guidance scale here ? different task than musicgeneration
+                prompt_hidden_states = torch.concatenate([prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0)

         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[2]
...
@@ -1952,6 +2027,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
                 remove_prefix_length = decoder_input_ids.shape[1] - 1

             decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+            # we only want to use prompt signal in the 1st generation step but keeping the attention mask
+            prompt_hidden_states = None

         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
...
@@ -2058,6 +2137,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state)

         return model_kwargs

+    def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs):
+        model_kwargs["prompt_hidden_states"] = self.embed_prompts(prompt_input_ids)
+        return model_kwargs
+
     def _prepare_audio_encoder_kwargs_for_generation(
         self, input_values, model_kwargs, model_input_name: Optional[str] = None
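Taken together with the `generate()` hunk further below, this helper means a caller can pass `prompt_input_ids` straight to `generate()` and the prompt embeddings are prepared automatically. A hypothetical call, assuming a trained checkpoint and tokenizers (all names illustrative, not from this commit):

    # Hypothetical usage; checkpoint path and tokenizer variables are illustrative only.
    # model = StableSpeechForConditionalGeneration.from_pretrained("path/to/checkpoint")
    # description = description_tokenizer("A calm female voice", return_tensors="pt")
    # prompt = prompt_tokenizer("Hello world", return_tensors="pt")
    #
    # audio_codes = model.generate(
    #     input_ids=description.input_ids,    # routed through the text encoder
    #     prompt_input_ids=prompt.input_ids,  # embedded via embed_prompts, then prepended
    #     prompt_attention_mask=prompt.attention_mask,
    # )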
...
@@ -2110,6 +2193,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)

     def resize_token_embeddings(self, *args, **kwargs):
+        # TODO: now it's possible with prompt_embeddings
         raise NotImplementedError(
             "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
             " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
...
@@ -2143,6 +2227,16 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
                 batch_size = value.shape[0]
                 break
         return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id

+    def freeze_encoders(self, freeze_text_encoder=True):
+        if freeze_text_encoder:
+            for param in self.text_encoder.parameters():
+                param.requires_grad = False
+            self.text_encoder._requires_grad = False
+
+        for param in self.audio_encoder.parameters():
+            param.requires_grad = False
+        self.audio_encoder._requires_grad = False
+
     @torch.no_grad()
     def generate(
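This helper complements the training path enabled by the new per-codebook loss: the audio encoder is always frozen, and the text encoder optionally, so only the decoder plus the new prompt and projection layers receive gradients. Hypothetical usage during fine-tuning (checkpoint path illustrative only):

    # model = StableSpeechForConditionalGeneration.from_pretrained("path/to/checkpoint")
    # model.freeze_encoders(freeze_text_encoder=True)
    # trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    # # -> decoder, embed_prompts and enc_to_dec_proj parameters only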
...
@@ -2277,6 +2371,13 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             model_input_name,
             guidance_scale=generation_config.guidance_scale,
         )

+        if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
+            # `prompt_hidden_states` are created and added to `model_kwargs`
+            model_kwargs = self._prepare_prompt_kwargs_for_generation(
+                model_kwargs["prompt_input_ids"],
+                model_kwargs,
+            )
+
         if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
             model_kwargs = self._prepare_audio_encoder_kwargs_for_generation(
...
@@ -2455,6 +2556,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
     def get_unconditional_inputs(self, num_samples=1):
         """
+        # TODO: Remove ?
         Helper function to get null inputs for unconditional generation, enabling the model to be used without the
         feature extractor or tokenizer.
...