chenpangpang / parler-tts / Commits / 813df4d2

Commit 813df4d2, authored Feb 21, 2024 by Yoach Lacombe
Parent: ca2cd16d

    update modeling code with prompt concat

Showing 3 changed files with 123 additions and 16 deletions (+123 -16)
stable_speech/__init__.py                      +2   -0
stable_speech/configuration_stable_speech.py   +4   -1
stable_speech/modeling_stable_speech.py        +117 -15
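The commit threads a `prompt_hidden_states` tensor through the decoder stack and concatenates it to the decoder token embeddings along the time dimension. As a standalone sketch of that fusion (toy shapes, illustrative only, not code from this commit):

import torch

bsz, prompt_len, seq_len, hidden = 2, 5, 10, 1024
prompt_hidden_states = torch.randn(bsz, prompt_len, hidden)  # embedded transcription prompt
inputs_embeds = torch.randn(bsz, seq_len, hidden)            # summed codebook embeddings

# the prompt is prepended along the sequence (time) dimension
fused = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
assert fused.shape == (bsz, prompt_len + seq_len, hidden)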
stable_speech/__init__.py  (view file @ 813df4d2)

+from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
+from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration
\ No newline at end of file
stable_speech/configuration_stable_speech.py  (view file @ 813df4d2)

...
@@ -137,6 +137,8 @@ class StableSpeechConfig(PretrainedConfig):
     documentation from [`PretrainedConfig`] for more information.

     Args:
+        prompt_embed_dim (`int`, *optional*, defaults to 1024):
+            Dimensionality of the prompt embedding layer.
         kwargs (*optional*):
             Dictionary of keyword arguments. Notably:
...
@@ -187,7 +189,7 @@ class StableSpeechConfig(PretrainedConfig):
     model_type = "stable_speech"
     is_composition = True

-    def __init__(self, **kwargs):
+    def __init__(self, prompt_embed_dim=1024, **kwargs):
         super().__init__(**kwargs)
         if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
             raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
...
@@ -200,6 +202,7 @@ class StableSpeechConfig(PretrainedConfig):
         decoder_config = kwargs.pop("decoder")

+        self.prompt_embed_dim = prompt_embed_dim
         self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
         self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
         self.decoder = StableSpeechDecoderConfig(**decoder_config)
...
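For reference, a hedged sketch of constructing the updated config; the sub-config dictionaries below are placeholder values (only `prompt_embed_dim` is new in this commit), and all three sub-configs remain mandatory:

from stable_speech import StableSpeechConfig

# hypothetical sub-configs; `model_type` is consumed by AutoConfig.for_model
config = StableSpeechConfig(
    prompt_embed_dim=1024,                    # new argument, defaults to 1024
    text_encoder={"model_type": "t5"},
    audio_encoder={"model_type": "encodec"},
    decoder={},                               # forwarded to StableSpeechDecoderConfig
)
assert config.prompt_embed_dim == 1024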
stable_speech/modeling_stable_speech.py  (view file @ 813df4d2)

...
@@ -689,6 +689,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
...
@@ -724,6 +726,22 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])

+        # if prompt_hidden_states are given, fuse them into inputs_embeds and update the input shape
+        if prompt_hidden_states is not None:
+            inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
+            input_shape = inputs_embeds.size()[:-1]
+
+        # TODO: verify if a prompt attention mask is required
+        # As it is, the masked ids from the prompt will still count in the position embeddings
+        if prompt_attention_mask is not None and attention_mask is not None:
+            attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
+        elif prompt_attention_mask is not None:
+            logger.warning_once(
+                "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
+            )
+            attention_mask = torch.cat(
+                [prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)]
+            )

         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
...
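The TODO above flags that the mask handling is still tentative: concatenating the masks mirrors the embedding concat, but masked prompt positions still consume position embeddings, and the fallback branch builds its ones from the already-fused `input_shape`. The intended fusion, sketched with toy shapes:

import torch

bsz, prompt_len, seq_len = 2, 5, 10
prompt_attention_mask = torch.ones(bsz, prompt_len, dtype=torch.long)
attention_mask = torch.ones(bsz, seq_len, dtype=torch.long)

# prompt positions are prepended, so their mask is prepended too
fused_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
assert fused_mask.shape == (bsz, prompt_len + seq_len)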
@@ -862,6 +880,8 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
...
@@ -884,6 +904,8 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
             encoder_hidden_states=encoder_hidden_states,
+            prompt_hidden_states=prompt_hidden_states,
+            prompt_attention_mask=prompt_attention_mask,
             head_mask=head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
...
@@ -951,6 +973,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
...
@@ -962,7 +986,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
             are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
...
@@ -976,6 +1000,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
             attention_mask=attention_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
+            prompt_hidden_states=prompt_hidden_states,
+            prompt_attention_mask=prompt_attention_mask,
             head_mask=head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
...
@@ -992,7 +1018,17 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         loss = None
         if labels is not None:
-            raise NotImplementedError("Training is not implemented for StableSpeech.")
+            # since the prompt/encoder hidden states have been concatenated to the hidden states,
+            # take only the last hidden states corresponding to the labels
+            logits = lm_logits[:, :, -labels.shape[1]:]
+
+            loss_fct = CrossEntropyLoss()
+            loss = torch.zeros([], device=self.device)
+
+            # per-codebook cross-entropy; -100 labels are ignored
+            # (bsz, vocab_size, seq_len, num_codebooks), (bsz, seq_len, num_codebooks)
+            labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
+            loss = loss_fct(logits.transpose(1, 3), labels)

         # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
         lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
...
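The per-codebook loss leans on `CrossEntropyLoss` accepting extra trailing dimensions: input `(N, C, d1, d2)` scored against target `(N, d1, d2)`, with `-100` targets ignored by default. A shape-level sketch with toy sizes:

import torch
from torch.nn import CrossEntropyLoss

bsz, num_codebooks, seq_len, vocab_size = 2, 4, 10, 2048
lm_logits = torch.randn(bsz, num_codebooks, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (bsz, seq_len, num_codebooks))

# transpose(1, 3): (bsz, num_codebooks, seq_len, vocab_size) -> (bsz, vocab_size, seq_len, num_codebooks)
loss = CrossEntropyLoss()(lm_logits.transpose(1, 3), labels)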
@@ -1016,6 +1052,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        prompt_hidden_states=None,
+        prompt_attention_mask=None,
         head_mask=None,
         cross_attn_head_mask=None,
         past_key_values=None,
...
@@ -1040,15 +1078,30 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
             input_ids = input_ids.repeat((2, 1))
             if attention_mask is not None:
                 attention_mask = attention_mask.repeat((2, 1))

+            if prompt_hidden_states is not None:
+                prompt_hidden_states = torch.concatenate(
+                    [prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0
+                )
+
+            if prompt_attention_mask is not None:
+                prompt_attention_mask = torch.concatenate(
+                    [prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0
+                )

         if past_key_values is not None:
             input_ids = input_ids[:, -1:]

+            # we only want to use the prompt signal in the 1st generation step, but we keep the attention mask
+            prompt_hidden_states = None

         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "encoder_hidden_states": encoder_hidden_states,
             "encoder_attention_mask": encoder_attention_mask,
             "prompt_hidden_states": prompt_hidden_states,
             "prompt_attention_mask": prompt_attention_mask,
             "head_mask": head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
             "past_key_values": past_key_values,
...
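Doubling the batch and zeroing the second half of the prompt follows the classifier-free guidance recipe inherited from MusicGen: the first half stays conditioned, the second half is unconditional. A standalone sketch:

import torch

bsz, prompt_len, hidden = 2, 5, 1024
prompt_hidden_states = torch.randn(bsz, prompt_len, hidden)

# conditional half on top, unconditional (zeroed) half below
cfg_prompt = torch.concatenate([prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0)
assert cfg_prompt.shape == (2 * bsz, prompt_len, hidden)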
@@ -1483,6 +1536,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             and self.decoder.config.cross_attention_hidden_size is None
         ):
             self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)

+        # prompt embeddings
+        self.embed_prompts = nn.Embedding(config.prompt_embed_dim, self.decoder.config.hidden_size)

         if self.text_encoder.get_output_embeddings() is not None:
             raise ValueError(
...
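Note that `prompt_embed_dim` serves here as the `num_embeddings` argument, i.e. as the prompt vocabulary size, while the embedding width follows the decoder hidden size. A lookup sketch with hypothetical sizes:

import torch
import torch.nn as nn

prompt_embed_dim, decoder_hidden = 1024, 1536   # hypothetical sizes
embed_prompts = nn.Embedding(prompt_embed_dim, decoder_hidden)

prompt_input_ids = torch.randint(0, prompt_embed_dim, (2, 5))  # (bsz, prompt_len)
prompt_hidden_states = embed_prompts(prompt_input_ids)         # (bsz, prompt_len, decoder_hidden)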
@@ -1496,8 +1553,19 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
)
# tie text encoder, decoder weights if config set accordingly
self
.
tie_weights
()
# Initialize projection and embedding layers and tie text encoder and decoder weights if set accordingly
self
.
post_init
()
def
_init_weights
(
self
,
module
):
std
=
self
.
config
.
initializer_factor
if
isinstance
(
module
,
(
nn
.
Linear
,
nn
.
Conv1d
)):
module
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
std
)
if
module
.
bias
is
not
None
:
module
.
bias
.
data
.
zero_
()
elif
isinstance
(
module
,
nn
.
Embedding
):
module
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
std
)
if
module
.
padding_idx
is
not
None
:
module
.
weight
.
data
[
module
.
padding_idx
].
zero_
()
def
tie_weights
(
self
):
# tie text encoder & decoder if needed
...
...
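Moving from a bare `tie_weights()` call to `post_init()` matters because the new `enc_to_dec_proj` and `embed_prompts` layers now get initialized through `_init_weights`. Roughly (a sketch, not the exact transformers internals):

# post_init() ultimately visits every submodule and applies _init_weights,
# then ties weights if the config asks for it
model.apply(model._init_weights)
model.tie_weights()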
@@ -1768,6 +1836,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
past_key_values
:
Tuple
[
Tuple
[
torch
.
FloatTensor
]]
=
None
,
inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
decoder_inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
prompt_input_ids
:
Optional
[
torch
.
FloatTensor
]
=
None
,
# TODO: add to docstrings
prompt_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
# TODO: add to docstrings
prompt_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
# TODO: add to docstrings
labels
:
Optional
[
torch
.
LongTensor
]
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
...
...
@@ -1844,6 +1915,11 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         if attention_mask is not None:
             encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]

+        if prompt_hidden_states is None:
+            if prompt_input_ids is not None:
+                prompt_hidden_states = self.embed_prompts(prompt_input_ids)
+
+        # TODO: do we do something with prompt_attention_mask? e.g. multiply it with prompt_hidden_states?

         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
             decoder_input_ids = shift_tokens_right(
...
@@ -1876,29 +1952,23 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
attention_mask
=
decoder_attention_mask
,
encoder_hidden_states
=
encoder_hidden_states
,
encoder_attention_mask
=
attention_mask
,
prompt_hidden_states
=
prompt_hidden_states
,
prompt_attention_mask
=
prompt_attention_mask
,
inputs_embeds
=
decoder_inputs_embeds
,
output_attentions
=
output_attentions
,
output_hidden_states
=
output_hidden_states
,
use_cache
=
use_cache
,
past_key_values
=
past_key_values
,
return_dict
=
return_dict
,
labels
=
labels
,
**
kwargs_decoder
,
)
loss
=
None
if
labels
is
not
None
:
logits
=
decoder_outputs
.
logits
if
return_dict
else
decoder_outputs
[
0
]
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
config
.
vocab_size
),
labels
.
view
(
-
1
))
if
not
return_dict
:
if
loss
is
not
None
:
return
(
loss
,)
+
decoder_outputs
+
encoder_outputs
else
:
return
decoder_outputs
+
encoder_outputs
return
decoder_outputs
+
(
encoder_hidden_states
,)
return
Seq2SeqLMOutput
(
loss
=
loss
,
loss
=
decoder_outputs
.
loss
,
logits
=
decoder_outputs
.
logits
,
past_key_values
=
decoder_outputs
.
past_key_values
,
decoder_hidden_states
=
decoder_outputs
.
hidden_states
,
...
...
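With the loss now computed inside the decoder (which knows the prompt offset), the wrapper simply forwards `labels` and surfaces `decoder_outputs.loss`. A hedged usage sketch, assuming an instantiated `model` and tensors shaped per the docstring (labels are `(batch_size, sequence_length, num_codebooks)`):

# hypothetical tensors, names illustrative
outputs = model(
    input_ids=input_ids,                # description tokens, routed through the text encoder
    prompt_input_ids=prompt_input_ids,  # transcription tokens, embedded via embed_prompts
    labels=labels,                      # audio codebook targets
)
outputs.loss.backward()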
@@ -1917,6 +1987,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
head_mask
=
None
,
decoder_attention_mask
=
None
,
decoder_head_mask
=
None
,
prompt_hidden_states
=
None
,
prompt_attention_mask
=
None
,
cross_attn_head_mask
=
None
,
use_cache
=
None
,
encoder_outputs
=
None
,
...
...
@@ -1940,6 +2012,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
decoder_input_ids
=
decoder_input_ids
.
repeat
((
2
,
1
))
if
decoder_attention_mask
is
not
None
:
decoder_attention_mask
=
decoder_attention_mask
.
repeat
((
2
,
1
))
if
prompt_hidden_states
is
not
None
:
# TODO: ? we probably don't want to keep guidance scale here ? different task than musicgeneration
prompt_hidden_states
=
torch
.
concatenate
([
prompt_hidden_states
,
torch
.
zeros_like
(
prompt_hidden_states
)],
dim
=
0
)
if
past_key_values
is
not
None
:
past_length
=
past_key_values
[
0
][
0
].
shape
[
2
]
...
...
@@ -1952,6 +2027,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
remove_prefix_length
=
decoder_input_ids
.
shape
[
1
]
-
1
decoder_input_ids
=
decoder_input_ids
[:,
remove_prefix_length
:]
# we only want to use prompt signal in the 1st generation step but keeping the attention mask
prompt_hidden_states
=
None
return
{
"input_ids"
:
None
,
# encoder_outputs is defined. input_ids not needed
...
...
@@ -2058,6 +2137,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
model_kwargs
[
"encoder_outputs"
]
=
BaseModelOutput
(
last_hidden_state
=
last_hidden_state
)
return
model_kwargs
def
_prepare_prompt_kwargs_for_generation
(
self
,
prompt_input_ids
,
model_kwargs
):
model_kwargs
[
"prompt_hidden_states"
]
=
self
.
embed_prompts
(
prompt_input_ids
)
return
model_kwargs
def
_prepare_audio_encoder_kwargs_for_generation
(
self
,
input_values
,
model_kwargs
,
model_input_name
:
Optional
[
str
]
=
None
...
...
@@ -2110,6 +2193,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
return
shift_tokens_right
(
labels
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
)
def
resize_token_embeddings
(
self
,
*
args
,
**
kwargs
):
# TODO: now it's possible with prompt_embeddings
raise
NotImplementedError
(
"Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
" respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
...
...
@@ -2143,6 +2227,16 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
batch_size
=
value
.
shape
[
0
]
break
return
torch
.
ones
((
batch_size
,
1
),
dtype
=
torch
.
long
,
device
=
self
.
device
)
*
bos_token_id
def
freeze_encoders
(
self
,
freeze_text_encoder
=
True
):
if
freeze_text_encoder
:
for
param
in
self
.
text_encoder
.
parameters
():
param
.
requires_grad
=
False
self
.
text_encoder
.
_requires_grad
=
False
for
param
in
self
.
audio_encoder
.
parameters
():
param
.
requires_grad
=
False
self
.
audio_encoder
.
_requires_grad
=
False
@
torch
.
no_grad
()
def
generate
(
...
...
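A usage sketch for the new helper: freeze both encoders so that only the decoder (and the prompt embeddings) keep training:

model.freeze_encoders(freeze_text_encoder=True)

# neither encoder receives gradients anymore
assert all(not p.requires_grad for p in model.text_encoder.parameters())
assert all(not p.requires_grad for p in model.audio_encoder.parameters())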
@@ -2277,6 +2371,13 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
model_input_name
,
guidance_scale
=
generation_config
.
guidance_scale
,
)
if
"prompt_hidden_states"
not
in
model_kwargs
and
"prompt_input_ids"
in
model_kwargs
:
# `prompt_hidden_states` are created and added to `model_kwargs`
model_kwargs
=
self
.
_prepare_prompt_kwargs_for_generation
(
model_kwargs
[
"prompt_input_ids"
],
model_kwargs
,
)
if
"decoder_input_ids"
not
in
model_kwargs
and
"input_values"
in
model_kwargs
:
model_kwargs
=
self
.
_prepare_audio_encoder_kwargs_for_generation
(
...
...
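End to end, generation can now be conditioned on a transcription prompt. A heavily hedged sketch, assuming a `tokenizer` for both inputs; the exact preprocessing is outside this diff:

# hypothetical preprocessing, names illustrative
desc = tokenizer("A calm female voice.", return_tensors="pt")
prompt = tokenizer("Hey, how are you doing today?", return_tensors="pt")

audio_codes = model.generate(
    input_ids=desc.input_ids,           # routed through the text encoder
    prompt_input_ids=prompt.input_ids,  # embedded and concatenated to decoder inputs
)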
@@ -2455,6 +2556,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
def
get_unconditional_inputs
(
self
,
num_samples
=
1
):
"""
# TODO: Remove ?
Helper function to get null inputs for unconditional generation, enabling the model to be used without the
feature extractor or tokenizer.
...
...