chenpangpang / parler-tts / Commits

Commit a4f464fe
Authored Apr 08, 2024 by Yoach Lacombe

    clean modeling code

Parent: 43087d4a
Showing 2 changed files with 53 additions and 60 deletions (+53, -60):

    parler_tts/configuration_parler_tts.py   +3  -2
    parler_tts/modeling_parler_tts.py        +50 -58
parler_tts/configuration_parler_tts.py

...
@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
             represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
             Dimensionality of the layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 24):
...
@@ -138,7 +138,8 @@ class ParlerTTSConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 1024):
-            Vocabulary size of the prompt # TODO.
+            Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
+            represented by the `prompt_inputs_ids`.
         kwargs (*optional*):
             Dictionary of keyword arguments. Notably:
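For context, the prompt vocabulary documented here is the one produced by the text tokenizer that generates `prompt_input_ids`. A minimal sanity-check sketch (the tokenizer checkpoint below is illustrative, not prescribed by this commit):

```python
from transformers import AutoTokenizer

# Illustrative tokenizer; Parler-TTS uses a T5-style text tokenizer for prompts.
prompt_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")

# The config's prompt vocab_size must cover every id the tokenizer can emit.
print(prompt_tokenizer.vocab_size)
```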
...
parler_tts/modeling_parler_tts.py
...
@@ -219,7 +219,6 @@ class ParlerTTSSinusoidalPositionalEmbedding(nn.Module):
         position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)

         # expand embeddings if needed
         if seq_len > self.weights.size(0):
-            # TODO: doesn't work
             self.make_weights(seq_len + self.offset, self.embedding_dim)
         return self.weights.index_select(0, position_ids.view(-1)).detach()
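The hunk above looks up rows of a precomputed sinusoidal table, offsetting positions by `past_key_values_length` so cached decoding keeps absolute positions, and regrows the table via `make_weights` when a sequence outruns it. A rough sketch of the kind of table `make_weights` builds (classic fairseq-style construction; not this module's exact code):

```python
import math
import torch

def build_sinusoidal_weights(num_positions: int, embedding_dim: int) -> torch.Tensor:
    # sin/cos table, one row per absolute position
    half_dim = embedding_dim // 2
    freqs = torch.exp(torch.arange(half_dim) * -(math.log(10000.0) / (half_dim - 1)))
    angles = torch.arange(num_positions).unsqueeze(1) * freqs.unsqueeze(0)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

weights = build_sinusoidal_weights(1024, 64)
position_ids = torch.arange(10) + 5  # past_key_values_length = 5
embeddings = weights.index_select(0, position_ids).detach()  # as in the hunk
print(embeddings.shape)  # torch.Size([10, 64])
```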
...
@@ -632,6 +631,25 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
             If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
             of `inputs_embeds`.
+        prompt_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input prompt sequence tokens in the vocabulary. Padding will be ignored by default should you
+            provide it.
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+            [What are input IDs?](../glossary#input-ids)
+        prompt_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding prompt token indices. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
+        prompt_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `prompt_input_ids` you can choose to directly pass an embedded representation.
+            This is useful if you want more control over how to convert `prompt_input_ids` indices into associated vectors
+            than the model's internal embedding lookup matrix.
         use_cache (`bool`, *optional*):
             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
             `past_key_values`).
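These new docstring entries mirror how the model is driven in practice: the voice description goes through the text encoder as `input_ids`, while the transcript to speak goes in as `prompt_input_ids`. A hedged usage sketch (the checkpoint name is illustrative; this commit predates a settled release name):

```python
from transformers import AutoTokenizer
from parler_tts import ParlerTTSForConditionalGeneration

model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler_tts_mini_v0.1")
tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler_tts_mini_v0.1")

# description conditions the text encoder; prompt is the transcript to speak
description = tokenizer("A calm female voice, studio-quality recording.", return_tensors="pt")
prompt = tokenizer("Hey, how are you doing today?", return_tensors="pt")

generation = model.generate(
    input_ids=description.input_ids,
    attention_mask=description.attention_mask,
    prompt_input_ids=prompt.input_ids,
    prompt_attention_mask=prompt.attention_mask,
)
audio = generation.cpu().numpy().squeeze()
```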
...
@@ -683,6 +701,16 @@ MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
             [What are attention masks?](../glossary#attention-mask)
+        prompt_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+            Sequence of prompt hidden-states at the output of the initial embedding layer. Concatenated to the input embeds.
+        prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+            Mask to avoid performing cross-attention on padding tokens indices of prompt input_ids. Mask values
+            selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
         head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
             Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
...
@@ -738,7 +766,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
         self.num_codebooks = config.num_codebooks
         self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

-        # TODO: not right dim
+        # TODO(YL): actually doesn't need the +1 if initialized correctly. Too late to change now.
         embed_dim = config.vocab_size + 1  # + 1 for pad token id
         self.embed_tokens = nn.ModuleList(
             [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
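For reference, these per-codebook tables are summed into a single hidden state, MusicGen-style, and `embed_dim` keeps the extra `+ 1` slot for the pad id that the updated TODO says is now baked in. A self-contained sketch of that summation (sizes illustrative, not the checkpoint's):

```python
import torch
import torch.nn as nn

num_codebooks, vocab_size, hidden_size = 9, 1088, 1024
embed_dim = vocab_size + 1  # + 1 for pad token id, as in the diff
embed_tokens = nn.ModuleList(
    [nn.Embedding(embed_dim, hidden_size) for _ in range(num_codebooks)]
)

# input ids arrive as (bsz, num_codebooks, seq_len); embeddings are summed over codebooks
input_ids = torch.randint(0, vocab_size, (2, num_codebooks, 50))
inputs_embeds = sum(embed_tokens[k](input_ids[:, k]) for k in range(num_codebooks))
print(inputs_embeds.shape)  # torch.Size([2, 50, 1024])
```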
...
@@ -769,8 +797,8 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
-        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
-        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
...
@@ -978,8 +1006,8 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
-        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
-        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
...
@@ -1071,8 +1099,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
-        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
-        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.LongTensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
...
@@ -1088,7 +1116,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-            # TODO: delay_pattern_mask
         Returns:
         """
...
@@ -1263,7 +1290,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         **kwargs,
     ):
         """
-        # TODO: adapt this generate with latest change
         Generates sequences of token ids for models with a language modeling head.

         <Tip warning={true}>
...
@@ -1504,15 +1530,20 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             output_ids = outputs.sequences
         else:
             output_ids = outputs

         # apply the pattern mask to the final ids
         output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])

         # revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-        output_ids = output_ids[
-            (model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)
-            & (model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)
-        ].reshape(batch_size, self.decoder.num_codebooks, -1)
+        _, mask = self.build_delay_pattern_mask(
+            input_ids,
+            bos_token_id=generation_config.bos_token_id,
+            pad_token_id=generation_config.pad_token_id,
+            max_length=output_ids.shape[1],
+        )
+        mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
+        output_ids = output_ids[mask].reshape(batch_size, self.num_codebooks, -1)

         if generation_config.return_dict_in_generate:
             outputs.sequences = output_ids
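The rewritten block rebuilds the codebook mask with `build_delay_pattern_mask` instead of reusing the cached `model_kwargs["delay_pattern_mask"]`, then drops bos/pad positions before reshaping to `(batch, num_codebooks, seq_len)`. A toy illustration of the delay pattern idea itself (not the model's implementation): codebook k is shifted right by k steps, bos fills the head and pad the tail, and reverting simply filters those sentinel positions out:

```python
import torch

def toy_delay_pattern(codes: torch.Tensor, bos_id: int, pad_id: int) -> torch.Tensor:
    # codes: (num_codebooks, seq_len) -> delayed: (num_codebooks, seq_len + num_codebooks)
    num_codebooks, seq_len = codes.shape
    out = torch.full((num_codebooks, seq_len + num_codebooks), pad_id, dtype=codes.dtype)
    for k in range(num_codebooks):
        out[k, :k] = bos_id                 # bos padding at the head
        out[k, k : k + seq_len] = codes[k]  # codebook k delayed by k steps
    return out

codes = torch.arange(1, 9).reshape(2, 4)  # 2 codebooks, 4 timesteps
delayed = toy_delay_pattern(codes, bos_id=0, pad_id=-1)
mask = (delayed != 0) & (delayed != -1)    # analogous to the bos/pad filter above
recovered = delayed[mask].reshape(2, -1)   # back to (num_codebooks, seq_len)
print(torch.equal(recovered, codes))       # True
```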
...
@@ -1856,7 +1887,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             )

         if "config" not in kwargs_decoder:
-            # TODO: reput AutoConfig once added to transformers
             decoder_config, kwargs_decoder = ParlerTTSDecoderConfig.from_pretrained(
                 decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
             )
...
@@ -1906,9 +1936,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-        prompt_input_ids: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
-        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
-        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_input_ids: Optional[torch.FloatTensor] = None,
+        prompt_attention_mask: Optional[torch.LongTensor] = None,
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
...
@@ -1989,10 +2019,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         if prompt_hidden_states is None:
             if prompt_input_ids is not None:
                 prompt_hidden_states = self.embed_prompts(prompt_input_ids)
-                # TODO: verify prompt_attention_mask

         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
-            # TODO: verify it does what's expected
             decoder_input_ids = shift_tokens_right(
                 labels, self.config.pad_token_id, self.config.decoder_start_token_id
             ).transpose(1, 2)
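Here `shift_tokens_right` prepares teacher-forcing decoder inputs from the labels, and the trailing `.transpose(1, 2)` moves the codebook axis in front of time. A generic sketch of the standard seq2seq shift (the repo's helper also handles the multi-codebook label layout, so treat this as conceptual):

```python
import torch

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int) -> torch.Tensor:
    # prepend the start token, drop the last position, and replace -100
    # label placeholders with the pad id
    shifted = input_ids.new_zeros(input_ids.shape)
    shifted[:, 1:] = input_ids[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

labels = torch.tensor([[11, 12, 13, -100]])
print(shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=2))
# tensor([[ 2, 11, 12, 13]])
```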
...
@@ -2267,7 +2295,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2)

     def resize_token_embeddings(self, *args, **kwargs):
-        # TODO: now it's possible with prompt_embeddings
         raise NotImplementedError(
             "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
             " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
...
@@ -2656,39 +2683,4 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             outputs.sequences = output_values
             return outputs
         else:
             return output_values
\ No newline at end of file
-
-    def get_unconditional_inputs(self, num_samples=1):
-        """
-        # TODO: Remove ?
-        Helper function to get null inputs for unconditional generation, enabling the model to be used without the
-        feature extractor or tokenizer.
-
-        Args:
-            num_samples (int, *optional*):
-                Number of audio samples to unconditionally generate.
-            max_new_tokens (int, *optional*):
-                Number of tokens to generate for each sample. More tokens means longer audio samples, at the expense of
-                longer inference (since more audio tokens need to be generated per sample).
-
-        Example:
-        ```python
-        >>> from transformers import ParlerTTSForConditionalGeneration
-
-        >>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
-
-        >>> # get the unconditional (or 'null') inputs for the model
-        >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
-        >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
-        ```"""
-        last_hidden_state = torch.zeros(
-            (num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype
-        )
-
-        attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)
-
-        return ParlerTTSUnconditionalInput(
-            encoder_outputs=(last_hidden_state,),
-            attention_mask=attention_mask,
-            guidance_scale=1.0,
-        )