parler-tts · Commit a4f464fe
authored Apr 08, 2024 by Yoach Lacombe

clean modeling code

parent 43087d4a
Showing 2 changed files with 53 additions and 60 deletions.

parler_tts/configuration_parler_tts.py (+3, -2)
parler_tts/modeling_parler_tts.py (+50, -58)
parler_tts/configuration_parler_tts.py
...
...
@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 2049):
             Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
-            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
+            represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
         hidden_size (`int`, *optional*, defaults to 1024):
             Dimensionality of the layers and the pooler layer.
         num_hidden_layers (`int`, *optional*, defaults to 24):
...
...
@@ -138,7 +138,8 @@ class ParlerTTSConfig(PretrainedConfig):
     Args:
         vocab_size (`int`, *optional*, defaults to 1024):
-            Vocabulary size of the prompt # TODO.
+            Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
+            represented by the `prompt_inputs_ids`.
         kwargs (*optional*):
             Dictionary of keyword arguments. Notably:
...
...
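The two `vocab_size` fields above are easy to conflate: the decoder config counts audio (codebook) tokens, the top-level config counts prompt token ids. A minimal sketch, assuming the `parler_tts` package from this repo is importable and exports the config class:

```python
# Hypothetical illustration of the decoder-side vocab_size documented above.
from parler_tts import ParlerTTSDecoderConfig

decoder_config = ParlerTTSDecoderConfig(
    vocab_size=2049,       # audio token vocabulary, per the docstring default
    hidden_size=1024,
    num_hidden_layers=24,
)
print(decoder_config.vocab_size)  # -> 2049
```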
parler_tts/modeling_parler_tts.py
...
...
@@ -219,7 +219,6 @@ class ParlerTTSSinusoidalPositionalEmbedding(nn.Module):
         position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
         # expand embeddings if needed
         if seq_len > self.weights.size(0):
-            # TODO: doesn't work
             self.make_weights(seq_len + self.offset, self.embedding_dim)
         return self.weights.index_select(0, position_ids.view(-1)).detach()
...
...
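For readers skimming the hunk above: `position_ids` are offset by the KV-cache length, so a token decoded with a cache gets the same sinusoidal embedding it would get without one. A standalone sketch of that lookup (illustrative shapes, not the repo's class):

```python
import torch

# stand-in for the precomputed sinusoidal table self.weights: (num_positions, dim)
weights = torch.randn(64, 8)
past_key_values_length = 10   # tokens already in the KV cache
seq_len = 1                   # one new token per step during cached decoding

position_ids = torch.arange(seq_len) + past_key_values_length
emb = weights.index_select(0, position_ids.view(-1)).detach()
print(emb.shape)  # torch.Size([1, 8])
```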
@@ -632,6 +631,25 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
             If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
             of `inputs_embeds`.
+        prompt_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+            Indices of input prompt sequence tokens in the vocabulary. Padding will be ignored by default should you
+            provide it.
+            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+            [`PreTrainedTokenizer.__call__`] for details.
+            [What are input IDs?](../glossary#input-ids)
+        prompt_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Mask to avoid performing attention on padding prompt token indices. Mask values selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
+        prompt_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing `prompt_input_ids` you can choose to directly pass an embedded
+            representation. This is useful if you want more control over how to convert `prompt_input_ids` indices
+            into associated vectors than the model's internal embedding lookup matrix.
         use_cache (`bool`, *optional*):
             If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
             `past_key_values`).
...
...
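The `prompt_input_ids` / `prompt_attention_mask` pair documented above comes straight from a tokenizer. A hedged example (the checkpoint name is a placeholder, not one shipped by this repo):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # placeholder checkpoint
enc = tokenizer(
    ["A calm female voice.", "Fast speech."],
    return_tensors="pt",
    padding=True,
)

prompt_input_ids = enc.input_ids            # (batch_size, sequence_length)
prompt_attention_mask = enc.attention_mask  # 1 = real token, 0 = padding
```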
@@ -683,6 +701,16 @@ MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
             - 1 for tokens that are **not masked**,
             - 0 for tokens that are **masked**.
             [What are attention masks?](../glossary#attention-mask)
+        prompt_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+            Sequence of prompt hidden-states at the output of the initial embedding layer. Concatenated to the input
+            embeds.
+        prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+            Mask to avoid performing cross-attention on padding tokens indices of prompt input_ids. Mask values
+            selected in `[0, 1]`:
+            - 1 for tokens that are **not masked**,
+            - 0 for tokens that are **masked**.
+            [What are attention masks?](../glossary#attention-mask)
         head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
             Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
...
...
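The "concatenated to the input embeds" behavior described for `prompt_hidden_states` is a plain sequence-dimension concat. A shape-only sketch (illustrative dimensions; the real decoder also extends the attention mask to cover the prompt):

```python
import torch

batch, prompt_len, seq_len, hidden = 2, 5, 7, 16
prompt_hidden_states = torch.randn(batch, prompt_len, hidden)
inputs_embeds = torch.randn(batch, seq_len, hidden)

decoder_inputs = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
print(decoder_inputs.shape)  # torch.Size([2, 12, 16])
```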
@@ -738,7 +766,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
         self.num_codebooks = config.num_codebooks
         self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0

-        # TODO: not right dim
+        # TODO(YL): actually doesn't need the +1 if initialized correctly. Too late to change now.
         embed_dim = config.vocab_size + 1  # + 1 for pad token id
         self.embed_tokens = nn.ModuleList(
             [nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
         )
...
...
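The hunk above keeps one embedding table per codebook, each with `vocab_size + 1` rows to leave room for the pad id; elsewhere in the decoder the per-codebook embeddings are summed, as in MusicGen. A small sketch under those assumptions:

```python
import torch
import torch.nn as nn

num_codebooks, vocab_size, hidden_size = 4, 2048, 16
embed_dim = vocab_size + 1  # + 1 for pad token id, as in the hunk above
embed_tokens = nn.ModuleList(
    [nn.Embedding(embed_dim, hidden_size) for _ in range(num_codebooks)]
)

# (batch, num_codebooks, seq_len) audio token ids
input_ids = torch.randint(0, embed_dim, (2, num_codebooks, 7))
inputs_embeds = sum(embed_tokens[i](input_ids[:, i]) for i in range(num_codebooks))
print(inputs_embeds.shape)  # torch.Size([2, 7, 16])
```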
@@ -769,8 +797,8 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
encoder_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
encoder_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
prompt_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
# TODO: add to docstrings
prompt_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
# TODO: add to docstrings
prompt_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
prompt_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
cross_attn_head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
past_key_values
:
Optional
[
Tuple
[
Tuple
[
torch
.
FloatTensor
]]]
=
None
,
...
...
@@ -978,8 +1006,8 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
encoder_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
encoder_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
prompt_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
# TODO: add to docstrings
prompt_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
# TODO: add to docstrings
prompt_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
prompt_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
cross_attn_head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
past_key_values
:
Optional
[
Tuple
[
Tuple
[
torch
.
FloatTensor
]]]
=
None
,
...
...
@@ -1071,8 +1099,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
attention_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
encoder_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
encoder_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
prompt_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
# TODO: add to docstrings
prompt_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
# TODO: add to docstrings
prompt_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
prompt_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
cross_attn_head_mask
:
Optional
[
torch
.
Tensor
]
=
None
,
past_key_values
:
Optional
[
Tuple
[
Tuple
[
torch
.
FloatTensor
]]]
=
None
,
...
...
@@ -1088,7 +1116,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
             are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-            # TODO: delay_pattern_mask

         Returns:
         """
...
...
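The `-100` convention in the labels docstring above is PyTorch's default `ignore_index`: masked positions contribute nothing to the loss. A minimal demonstration:

```python
import torch
import torch.nn.functional as F

vocab_size = 10
logits = torch.randn(1, 4, vocab_size)
labels = torch.tensor([[3, 7, -100, -100]])  # last two positions ignored

loss = F.cross_entropy(
    logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100
)
print(loss.item())  # averaged over the two unmasked positions only
```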
@@ -1263,7 +1290,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
         **kwargs,
     ):
         """
-        # TODO: adapt this generate with latest change
         Generates sequences of token ids for models with a language modeling head.

         <Tip warning={true}>
...
...
@@ -1504,15 +1530,20 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
output_ids
=
outputs
.
sequences
else
:
output_ids
=
outputs
# apply the pattern mask to the final ids
output_ids
=
self
.
apply_delay_pattern_mask
(
output_ids
,
model_kwargs
[
"delay_pattern_mask"
])
# revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
output_ids
=
output_ids
[
(
model_kwargs
[
"delay_pattern_mask"
]
!=
generation_config
.
bos_token_id
)
&
(
model_kwargs
[
"delay_pattern_mask"
]
!=
generation_config
.
eos_token_id
)
].
reshape
(
batch_size
,
self
.
decoder
.
num_codebooks
,
-
1
)
_
,
mask
=
self
.
build_delay_pattern_mask
(
input_ids
,
bos_token_id
=
generation_config
.
bos_token_id
,
pad_token_id
=
generation_config
.
pad_token_id
,
max_length
=
output_ids
.
shape
[
1
],
)
mask
=
(
mask
!=
generation_config
.
bos_token_id
)
&
(
mask
!=
generation_config
.
pad_token_id
)
output_ids
=
output_ids
[
mask
].
reshape
(
batch_size
,
self
.
num_codebooks
,
-
1
)
if
generation_config
.
return_dict_in_generate
:
outputs
.
sequences
=
output_ids
...
...
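The rewritten block above rebuilds the delay pattern with `build_delay_pattern_mask`, keeps only positions that are neither bos nor pad, and reshapes to `(batch_size, num_codebooks, frames)`. A simplified standalone sketch with made-up token values (not the repo's helper):

```python
import torch

bos, pad = 2048, 2049
batch_size, num_codebooks = 1, 2

# flattened (batch * codebooks, length) ids, as generate() produces them;
# each codebook row carries a different number of structural bos/pad tokens
output_ids = torch.tensor([[bos, 11, 12, pad],
                           [bos, bos, 21, 22]])

mask = (output_ids != bos) & (output_ids != pad)
# every codebook keeps the same number of real tokens, so reshape is valid
audio_ids = output_ids[mask].reshape(batch_size, num_codebooks, -1)
print(audio_ids)  # tensor([[[11, 12], [21, 22]]])
```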
@@ -1856,7 +1887,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
             )

         if "config" not in kwargs_decoder:
-            # TODO: reput AutoConfig once added to transformers
             decoder_config, kwargs_decoder = ParlerTTSDecoderConfig.from_pretrained(
                 decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
             )
...
...
@@ -1906,9 +1936,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
past_key_values
:
Tuple
[
Tuple
[
torch
.
FloatTensor
]]
=
None
,
inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
decoder_inputs_embeds
:
Optional
[
torch
.
FloatTensor
]
=
None
,
prompt_input_ids
:
Optional
[
torch
.
FloatTensor
]
=
None
,
# TODO: add to docstrings
prompt_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
# TODO: add to docstrings
prompt_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
# TODO: add to docstrings
prompt_input_ids
:
Optional
[
torch
.
FloatTensor
]
=
None
,
prompt_attention_mask
:
Optional
[
torch
.
LongTensor
]
=
None
,
prompt_hidden_states
:
Optional
[
torch
.
FloatTensor
]
=
None
,
labels
:
Optional
[
torch
.
LongTensor
]
=
None
,
use_cache
:
Optional
[
bool
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
...
...
@@ -1989,10 +2019,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
if
prompt_hidden_states
is
None
:
if
prompt_input_ids
is
not
None
:
prompt_hidden_states
=
self
.
embed_prompts
(
prompt_input_ids
)
# TODO: verify prompt_attention_mask
if
(
labels
is
not
None
)
and
(
decoder_input_ids
is
None
and
decoder_inputs_embeds
is
None
):
# TODO: verify it does what's expected
decoder_input_ids
=
shift_tokens_right
(
labels
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
).
transpose
(
1
,
2
)
...
...
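`shift_tokens_right` in the hunk above prepares teacher-forcing inputs: labels move one step to the right, the decoder start token fills position 0, and `-100` loss-masking values are replaced by the pad id. A minimal 2-D stand-in (the repo's version also handles the codebook axis, hence the `.transpose(1, 2)`):

```python
import torch

def shift_tokens_right(labels, pad_token_id, decoder_start_token_id):
    """Shift label ids one position to the right, inserting the start token."""
    shifted = labels.new_zeros(labels.shape)
    shifted[..., 1:] = labels[..., :-1].clone()
    shifted[..., 0] = decoder_start_token_id
    shifted.masked_fill_(shifted == -100, pad_token_id)  # -100 is loss-only
    return shifted

labels = torch.tensor([[5, 6, 7, -100]])
print(shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=1))
# tensor([[1, 5, 6, 7]])
```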
@@ -2267,7 +2295,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
return
shift_tokens_right
(
labels
,
self
.
config
.
pad_token_id
,
self
.
config
.
decoder_start_token_id
).
transpose
(
1
,
2
)
def
resize_token_embeddings
(
self
,
*
args
,
**
kwargs
):
# TODO: now it's possible with prompt_embeddings
raise
NotImplementedError
(
"Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
" respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
...
...
@@ -2656,39 +2683,4 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
                 outputs.sequences = output_values
             return outputs
         else:
             return output_values

-    def get_unconditional_inputs(self, num_samples=1):
-        """
-        # TODO: Remove ?
-        Helper function to get null inputs for unconditional generation, enabling the model to be used without the
-        feature extractor or tokenizer.
-
-        Args:
-            num_samples (int, *optional*):
-                Number of audio samples to unconditionally generate.
-            max_new_tokens (int, *optional*):
-                Number of tokens to generate for each sample. More tokens means longer audio samples, at the expense
-                of longer inference (since more audio tokens need to be generated per sample).
-
-        Example:
-        ```python
-        >>> from transformers import ParlerTTSForConditionalGeneration
-
-        >>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
-
-        >>> # get the unconditional (or 'null') inputs for the model
-        >>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
-        >>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
-        ```"""
-        last_hidden_state = torch.zeros(
-            (num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype
-        )
-        attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)
-
-        return ParlerTTSUnconditionalInput(
-            encoder_outputs=(last_hidden_state,),
-            attention_mask=attention_mask,
-            guidance_scale=1.0,
-        )
-        return output_values
\ No newline at end of file