Commit a4f464fe authored by Yoach Lacombe

clean modeling code

parent 43087d4a
......@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
Args:
vocab_size (`int`, *optional*, defaults to 2049):
Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
represented by the `input_ids` passed when calling [`ParlerTTSDecoder`].
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 24):
......@@ -138,7 +138,8 @@ class ParlerTTSConfig(PretrainedConfig):
Args:
vocab_size (`int`, *optional*, defaults to 1024):
Vocabulary size of the prompt # TODO.
Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
represented by the `prompt_input_ids`.
kwargs (*optional*):
Dictionary of keyword arguments. Notably:
......
......@@ -219,7 +219,6 @@ class ParlerTTSSinusoidalPositionalEmbedding(nn.Module):
position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
# expand embeddings if needed
if seq_len > self.weights.size(0):
# TODO: doesn't work
self.make_weights(seq_len + self.offset, self.embedding_dim)
return self.weights.index_select(0, position_ids.view(-1)).detach()
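For reference, here is a minimal sketch of the lookup this forward pass performs: a precomputed sinusoidal table is indexed with positions offset by the number of tokens already in the KV cache. The table construction below is a simplified assumption, not the exact ParlerTTS layout.
```python
import math
import torch

# Simplified sinusoidal table (assumption: interleaved sin/cos, even embedding dim).
def sinusoidal_table(num_positions: int, dim: int) -> torch.Tensor:
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    table = torch.zeros(num_positions, dim)
    table[:, 0::2] = torch.sin(position * div_term)
    table[:, 1::2] = torch.cos(position * div_term)
    return table

weights = sinusoidal_table(16, 8)        # cached table, grown on demand in the module above
past_key_values_length = 4               # tokens already held in the KV cache
seq_len = 3                              # new tokens in this forward pass
position_ids = torch.arange(seq_len) + past_key_values_length
positional_embeddings = weights.index_select(0, position_ids)  # (seq_len, dim)
```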
......@@ -632,6 +631,25 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
of `inputs_embeds`.
prompt_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
Indices of input prompt sequence tokens in the vocabulary. Padding will be ignored by default should you provide
it.
Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
prompt_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding prompt token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
prompt_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `prompt_input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `prompt_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
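As a usage note for the prompt arguments documented above, here is a hedged sketch of how `prompt_input_ids` and `prompt_attention_mask` are typically produced with a Hugging Face tokenizer; the checkpoint name is a placeholder, not a real model id.
```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("your/prompt-tokenizer")  # placeholder checkpoint id
prompts = ["Hey, how are you doing today?", "It is sunny outside."]
batch = tokenizer(prompts, padding=True, return_tensors="pt")
prompt_input_ids = batch["input_ids"]            # (batch_size, sequence_length)
prompt_attention_mask = batch["attention_mask"]  # 1 = real token, 0 = padding
```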
......@@ -683,6 +701,16 @@ MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
prompt_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
Sequence of prompt hidden-states at the output of the initial embedding layer. These are concatenated to the input embeddings (see the sketch below).
prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
Mask to avoid performing cross-attention on padding token indices of the prompt `input_ids`. Mask values
selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
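The `prompt_hidden_states` entry above says the prompt embeddings are concatenated to the decoder input embeddings; a toy sketch of that concatenation along the time axis (sizes are example values, and prepending is an assumption, see the decoder forward for the exact placement):
```python
import torch

batch, prompt_len, seq_len, hidden = 2, 6, 10, 1024                       # example sizes only
prompt_hidden_states = torch.randn(batch, prompt_len, hidden)
inputs_embeds = torch.randn(batch, seq_len, hidden)
decoder_inputs = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)  # (batch, prompt_len + seq_len, hidden)
```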
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
......@@ -738,7 +766,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
self.num_codebooks = config.num_codebooks
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
# TODO: not right dim
# TODO(YL): actually doesn't need the +1 if initialized correctly. Too late to change now.
embed_dim = config.vocab_size + 1 # + 1 for pad token id
self.embed_tokens = nn.ModuleList(
[nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
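The `ModuleList` above keeps one embedding table per codebook; MusicGen-style decoders then sum the per-codebook embeddings into a single hidden state. A hedged sketch of that pattern, with example sizes rather than the actual config defaults:
```python
import torch
import torch.nn as nn

num_codebooks, vocab_size, hidden_size = 9, 1088, 1024           # example values only
embed_dim = vocab_size + 1                                        # + 1 for the pad token id, as above
embed_tokens = nn.ModuleList(
    [nn.Embedding(embed_dim, hidden_size) for _ in range(num_codebooks)]
)

input_ids = torch.randint(0, vocab_size, (2, num_codebooks, 5))   # (batch, num_codebooks, seq_len)
inputs_embeds = sum(embed_tokens[k](input_ids[:, k]) for k in range(num_codebooks))
print(inputs_embeds.shape)                                        # torch.Size([2, 5, 1024])
```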
......@@ -769,8 +797,8 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
prompt_hidden_states: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
......@@ -978,8 +1006,8 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
prompt_hidden_states: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
......@@ -1071,8 +1099,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
prompt_hidden_states: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
......@@ -1088,7 +1116,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
# TODO: delay_pattern_mask
Returns:
"""
......@@ -1263,7 +1290,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
**kwargs,
):
"""
# TODO: adapt this generate with latest change
Generates sequences of token ids for models with a language modeling head.
<Tip warning={true}>
......@@ -1504,15 +1530,20 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
output_ids = outputs.sequences
else:
output_ids = outputs
# apply the pattern mask to the final ids
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
# revert the delay pattern mask by filtering the eos and bos token ids
output_ids = output_ids[
(model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)
& (model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)
].reshape(batch_size, self.decoder.num_codebooks, -1)
_, mask = self.build_delay_pattern_mask(
input_ids,
bos_token_id=generation_config.bos_token_id,
pad_token_id=generation_config.pad_token_id,
max_length=output_ids.shape[1],
)
mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
output_ids = output_ids[mask].reshape(batch_size, self.num_codebooks, -1)
if generation_config.return_dict_in_generate:
outputs.sequences = output_ids
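A toy illustration of the masking-and-reshape step introduced above: positions that the delay pattern mask marks with the bos or pad id are dropped, and the surviving ids are regrouped per codebook. All ids and shapes below are made up for the example.
```python
import torch

bos, pad = 2048, 2049                           # example special token ids
# generated ids for one sample with 2 codebooks, laid out as (num_codebooks, seq_len)
output_ids = torch.tensor([[bos, 10, 11, pad],
                           [bos, bos, 20, 21]])
mask_ids = torch.tensor([[bos, 1, 1, pad],
                         [bos, bos, 1, 1]])     # non-special entries mark "keep the token here"
keep = (mask_ids != bos) & (mask_ids != pad)
clean = output_ids[keep].reshape(1, 2, -1)      # (batch_size, num_codebooks, steps)
print(clean)                                    # tensor([[[10, 11], [20, 21]]])
```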
......@@ -1856,7 +1887,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
)
if "config" not in kwargs_decoder:
# TODO: reput AutoConfig once added to transformers
decoder_config, kwargs_decoder = ParlerTTSDecoderConfig.from_pretrained(
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
)
......@@ -1906,9 +1936,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
prompt_input_ids: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_input_ids: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
......@@ -1989,10 +2019,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
if prompt_hidden_states is None:
if prompt_input_ids is not None:
prompt_hidden_states = self.embed_prompts(prompt_input_ids)
# TODO: verify prompt_attention_mask
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
# TODO: verify it does what's expected
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
).transpose(1, 2)
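A hedged sketch of the labels-to-`decoder_input_ids` path above: right-shift along the time axis, then swap the time and codebook axes. The label shape and ids are assumptions for illustration; see `shift_tokens_right` in this file for the exact behaviour (including `-100` handling).
```python
import torch

decoder_start_token_id = 2048                      # example id
labels = torch.randint(0, 100, (2, 10, 9))         # assumed (batch, seq_len, num_codebooks)
shifted = torch.full_like(labels, decoder_start_token_id)
shifted[:, 1:] = labels[:, :-1]
decoder_input_ids = shifted.transpose(1, 2)        # -> (batch, num_codebooks, seq_len)
print(decoder_input_ids.shape)                     # torch.Size([2, 9, 10])
```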
......@@ -2267,7 +2295,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2)
def resize_token_embeddings(self, *args, **kwargs):
# TODO: now it's possible with prompt_embeddings
raise NotImplementedError(
"Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
" respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
......@@ -2656,39 +2683,4 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
outputs.sequences = output_values
return outputs
else:
return output_values
def get_unconditional_inputs(self, num_samples=1):
"""
# TODO: Remove ?
Helper function to get null inputs for unconditional generation, enabling the model to be used without the
feature extractor or tokenizer.
Args:
num_samples (int, *optional*):
Number of audio samples to unconditionally generate.
max_new_tokens (int, *optional*):
Number of tokens to generate for each sample. More tokens means longer audio samples, at the expense of
longer inference (since more audio tokens need to be generated per sample).
Example:
```python
>>> from transformers import ParlerTTSForConditionalGeneration
>>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
>>> # get the unconditional (or 'null') inputs for the model
>>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
>>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
```"""
last_hidden_state = torch.zeros(
(num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype
)
attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)
return ParlerTTSUnconditionalInput(
encoder_outputs=(last_hidden_state,),
attention_mask=attention_mask,
guidance_scale=1.0,
)
return output_values
\ No newline at end of file