Commit a4f464fe authored by Yoach Lacombe

clean modeling code

parent 43087d4a
@@ -40,7 +40,7 @@ class ParlerTTSDecoderConfig(PretrainedConfig):
Args:
vocab_size (`int`, *optional*, defaults to 2049):
Vocabulary size of the ParlerTTSDecoder model. Defines the number of different tokens that can be
represented by the `inputs_ids` passed when calling [`ParlerTTSDecoder`].
hidden_size (`int`, *optional*, defaults to 1024):
Dimensionality of the layers and the pooler layer.
num_hidden_layers (`int`, *optional*, defaults to 24):
@@ -138,7 +138,8 @@ class ParlerTTSConfig(PretrainedConfig):
Args:
vocab_size (`int`, *optional*, defaults to 1024):
-Vocabulary size of the prompt # TODO.
+Vocabulary size of the prompt token ids. Defines the number of different tokens that can be
+represented by the `prompt_inputs_ids`.
kwargs (*optional*):
Dictionary of keyword arguments. Notably:
@@ -219,7 +219,6 @@ class ParlerTTSSinusoidalPositionalEmbedding(nn.Module):
position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
# expand embeddings if needed
if seq_len > self.weights.size(0):
-# TODO: doesn't work
self.make_weights(seq_len + self.offset, self.embedding_dim)
return self.weights.index_select(0, position_ids.view(-1)).detach()
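
Note on the hunk above: the cached sinusoidal table is regrown on the fly whenever the requested sequence length exceeds the number of rows already built. A minimal, self-contained sketch of that expand-on-demand pattern (the `make_weights`/`offset` names mirror the module above, but the table construction here is illustrative, not the committed implementation):

```python
import math
import torch
from torch import nn


class SinusoidalPositionalEmbedding(nn.Module):
    """Illustrative sketch: cache a sinusoidal table and grow it on demand."""

    def __init__(self, num_positions: int, embedding_dim: int):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.offset = 0  # assumption: no reserved positions
        self.make_weights(num_positions, embedding_dim)

    def make_weights(self, num_embeddings: int, embedding_dim: int):
        # half of the dims get sin, the other half cos, at geometrically spaced frequencies
        half_dim = embedding_dim // 2
        freq = torch.exp(torch.arange(half_dim) * -(math.log(10000.0) / (half_dim - 1)))
        angles = torch.arange(num_embeddings).unsqueeze(1) * freq.unsqueeze(0)
        weights = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
        # re-registering under the same name replaces the previously cached table
        self.register_buffer("weights", weights, persistent=False)

    @torch.no_grad()
    def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
        bsz, _, seq_len = input_ids.shape  # (batch, codebooks, seq_len)
        position_ids = (torch.arange(seq_len) + past_key_values_length).to(input_ids.device)
        # expand the cached table if the sequence is longer than what we have
        if seq_len > self.weights.size(0):
            self.make_weights(seq_len + self.offset, self.embedding_dim)
        return self.weights.index_select(0, position_ids.view(-1)).detach()


emb = SinusoidalPositionalEmbedding(num_positions=16, embedding_dim=8)
out = emb(torch.zeros(2, 9, 32, dtype=torch.long))  # 32 > 16, so the table is regrown
print(out.shape)  # torch.Size([32, 8])
```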
@@ -632,6 +631,25 @@ MUSICGEN_INPUTS_DOCSTRING = r"""
If `decoder_input_ids` and `decoder_inputs_embeds` are both unset, `decoder_inputs_embeds` takes the value
of `inputs_embeds`.
+prompt_input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+Indices of input prompt sequence tokens in the vocabulary. Padding will be ignored by default should you provide it.
+Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
+[`PreTrainedTokenizer.__call__`] for details.
+[What are input IDs?](../glossary#input-ids)
+prompt_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+Mask to avoid performing attention on padding prompt token indices. Mask values selected in `[0, 1]`:
+- 1 for tokens that are **not masked**,
+- 0 for tokens that are **masked**.
+[What are attention masks?](../glossary#attention-mask)
+prompt_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+Optionally, instead of passing `prompt_input_ids` you can choose to directly pass an embedded representation.
+This is useful if you want more control over how to convert `prompt_input_ids` indices into associated vectors
+than the model's internal embedding lookup matrix.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
`past_key_values`).
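
The three new docstring entries above describe the transcript-prompt route that sits alongside the usual description `input_ids`. A hedged shape illustration of what gets passed (the token ids, mask values, and hidden size below are made up for demonstration; only the shapes and dtypes follow the docstring):

```python
import torch

batch_size, prompt_len, hidden_size = 2, 6, 1024

# token-id route: `prompt_input_ids` is a LongTensor of shape (batch_size, sequence_length),
# normally produced by an AutoTokenizer; the ids here are fake and for shape illustration only
prompt_input_ids = torch.randint(0, 32000, (batch_size, prompt_len), dtype=torch.long)

# 1 = attend, 0 = padding to be ignored (second sample padded with two pad positions)
prompt_attention_mask = torch.tensor(
    [[1, 1, 1, 1, 1, 1],
     [1, 1, 1, 1, 0, 0]], dtype=torch.long
)

# embedding route: `prompt_inputs_embeds` bypasses the prompt embedding lookup entirely
prompt_inputs_embeds = torch.randn(batch_size, prompt_len, hidden_size)

# either route is passed alongside the usual description `input_ids` / `attention_mask`, e.g.
# model.generate(input_ids=..., attention_mask=...,
#                prompt_input_ids=prompt_input_ids,
#                prompt_attention_mask=prompt_attention_mask)
print(prompt_input_ids.shape, prompt_attention_mask.shape, prompt_inputs_embeds.shape)
```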
@@ -683,6 +701,16 @@ MUSICGEN_DECODER_INPUTS_DOCSTRING = r"""
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
+[What are attention masks?](../glossary#attention-mask)
+prompt_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`, *optional*):
+Sequence of prompt hidden-states at the output of the initial embedding layer. Concatenated to the input embeds.
+prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, encoder_sequence_length)`, *optional*):
+Mask to avoid performing cross-attention on padding token indices of the prompt input_ids. Mask values
+selected in `[0, 1]`:
+- 1 for tokens that are **not masked**,
+- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
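
As the new `prompt_hidden_states` entry says, the prompt embeddings are concatenated to the decoder input embeddings, so the two attention masks also have to be concatenated along the time axis. A shapes-only sketch of that step (prepending the prompt first is an assumption here, not something stated in the hunk):

```python
import torch

batch_size, hidden_size = 2, 1024
prompt_len, decoder_len = 6, 10

prompt_hidden_states = torch.randn(batch_size, prompt_len, hidden_size)
inputs_embeds = torch.randn(batch_size, decoder_len, hidden_size)

prompt_attention_mask = torch.ones(batch_size, prompt_len, dtype=torch.long)
decoder_attention_mask = torch.ones(batch_size, decoder_len, dtype=torch.long)

# put the prompt embeddings next to the decoder input embeddings along the time axis,
# and keep the masks aligned with the concatenated sequence
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
attention_mask = torch.cat([prompt_attention_mask, decoder_attention_mask], dim=1)

print(inputs_embeds.shape)   # torch.Size([2, 16, 1024])
print(attention_mask.shape)  # torch.Size([2, 16])
```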
@@ -738,7 +766,7 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
self.num_codebooks = config.num_codebooks
self.embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
-# TODO: not right dim
+# TODO(YL): actually doesn't need the +1 if initialized correctly. Too late to change now.
embed_dim = config.vocab_size + 1 # + 1 for pad token id
self.embed_tokens = nn.ModuleList(
[nn.Embedding(embed_dim, config.hidden_size) for _ in range(config.num_codebooks)]
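
On `embed_dim = config.vocab_size + 1`: each codebook gets its own embedding table, with one extra row reserved for the pad token id. In the MusicGen-style decoder this model derives from, the per-codebook embeddings are then summed into a single hidden state; a sketch under that assumption (numbers are illustrative):

```python
import torch
from torch import nn

vocab_size, hidden_size, num_codebooks = 2049, 1024, 9
embed_dim = vocab_size + 1  # extra row so the pad id (= vocab_size) is a valid index

embed_tokens = nn.ModuleList(
    [nn.Embedding(embed_dim, hidden_size) for _ in range(num_codebooks)]
)

# input ids arrive as (batch, num_codebooks, seq_len); ids may include the pad id vocab_size
input_ids = torch.randint(0, vocab_size + 1, (2, num_codebooks, 8))

# sum the per-codebook embeddings into one (batch, seq_len, hidden_size) tensor
inputs_embeds = sum(embed_tokens[cb](input_ids[:, cb]) for cb in range(num_codebooks))
print(inputs_embeds.shape)  # torch.Size([2, 8, 1024])
```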
@@ -769,8 +797,8 @@ class ParlerTTSDecoder(ParlerTTSPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
-prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
-prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
+prompt_hidden_states: Optional[torch.FloatTensor] = None,
+prompt_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
@@ -978,8 +1006,8 @@ class ParlerTTSModel(ParlerTTSPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
-prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
-prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
+prompt_hidden_states: Optional[torch.FloatTensor] = None,
+prompt_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
@@ -1071,8 +1099,8 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
-prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
-prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
+prompt_hidden_states: Optional[torch.FloatTensor] = None,
+prompt_attention_mask: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
@@ -1088,7 +1116,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
-# TODO: delay_pattern_mask
Returns:
"""
@@ -1263,7 +1290,6 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
**kwargs,
):
"""
-# TODO: adapt this generate with latest change
Generates sequences of token ids for models with a language modeling head.
<Tip warning={true}>
@@ -1504,15 +1530,20 @@ class ParlerTTSForCausalLM(ParlerTTSPreTrainedModel):
output_ids = outputs.sequences
else:
output_ids = outputs
# apply the pattern mask to the final ids
output_ids = self.apply_delay_pattern_mask(output_ids, model_kwargs["delay_pattern_mask"])
# revert the pattern delay mask by filtering the eos and bos token ids from the delay pattern mask
-output_ids = output_ids[
-(model_kwargs["delay_pattern_mask"] != generation_config.bos_token_id)
-& (model_kwargs["delay_pattern_mask"] != generation_config.eos_token_id)
-].reshape(batch_size, self.decoder.num_codebooks, -1)
+_, mask = self.build_delay_pattern_mask(
+input_ids,
+bos_token_id=generation_config.bos_token_id,
+pad_token_id=generation_config.pad_token_id,
+max_length=output_ids.shape[1],
+)
+mask = (mask != generation_config.bos_token_id) & (mask != generation_config.pad_token_id)
+output_ids = output_ids[mask].reshape(batch_size, self.num_codebooks, -1)
if generation_config.return_dict_in_generate:
outputs.sequences = output_ids
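
The replacement block rebuilds the delay pattern mask from `input_ids` instead of reusing the cached `model_kwargs["delay_pattern_mask"]`, then drops every position holding a BOS or pad id so only real codebook tokens survive the reshape. A toy illustration of that filter-and-reshape step (token ids and shapes are invented; `build_delay_pattern_mask` itself is not reproduced):

```python
import torch

bos_token_id, pad_token_id = 2049, 2048
batch_size, num_codebooks = 1, 2

# pretend output_ids is the (batch * codebooks, seq_len) output after apply_delay_pattern_mask:
# codebook k is delayed by k steps, so leading positions are BOS and trailing ones are pad
output_ids = torch.tensor([
    [bos_token_id, 11, 12, 13, pad_token_id],
    [bos_token_id, bos_token_id, 21, 22, 23],
])

# stand-in for the rebuilt delay pattern mask: BOS/pad positions in the mask are exactly
# the positions that carry no real audio token
mask = output_ids.clone()
keep = (mask != bos_token_id) & (mask != pad_token_id)

# every codebook keeps the same number of tokens, so the reshape is well defined
codebook_ids = output_ids[keep].reshape(batch_size, num_codebooks, -1)
print(codebook_ids)  # tensor([[[11, 12, 13], [21, 22, 23]]])
```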
@@ -1856,7 +1887,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
)
if "config" not in kwargs_decoder:
-# TODO: reput AutoConfig once added to transformers
decoder_config, kwargs_decoder = ParlerTTSDecoderConfig.from_pretrained(
decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
)
@@ -1906,9 +1936,9 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
-prompt_input_ids: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
-prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
-prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
+prompt_input_ids: Optional[torch.FloatTensor] = None,
+prompt_attention_mask: Optional[torch.LongTensor] = None,
+prompt_hidden_states: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -1989,10 +2019,8 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
if prompt_hidden_states is None:
if prompt_input_ids is not None:
prompt_hidden_states = self.embed_prompts(prompt_input_ids)
-# TODO: verify prompt_attention_mask
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
-# TODO: verify it does what's expected
decoder_input_ids = shift_tokens_right(
labels, self.config.pad_token_id, self.config.decoder_start_token_id
).transpose(1, 2)
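
For context on the `shift_tokens_right(...).transpose(1, 2)` call above: the shift is the standard seq2seq label shift (prepend the decoder start token, drop the last step, replace `-100` with the pad id), and the transpose moves labels from `(batch, seq_len, num_codebooks)` to the `(batch, num_codebooks, seq_len)` layout the decoder expects. A sketch of that common pattern, with made-up pad/start ids (the committed helper may differ in details):

```python
import torch


def shift_tokens_right_sketch(labels: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """Sketch of the usual seq2seq shift, applied along the time axis (dim 1)."""
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()              # everything moves one step to the right
    shifted[:, 0] = decoder_start_token_id               # first step becomes the decoder start token
    shifted.masked_fill_(shifted == -100, pad_token_id)  # ignored label positions become pad
    return shifted


# labels come in as (batch, seq_len, num_codebooks); the decoder wants (batch, num_codebooks, seq_len)
labels = torch.randint(0, 2048, (2, 5, 9))
decoder_input_ids = shift_tokens_right_sketch(
    labels, pad_token_id=2048, decoder_start_token_id=2049
).transpose(1, 2)
print(decoder_input_ids.shape)  # torch.Size([2, 9, 5])
```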
@@ -2267,7 +2295,6 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id).transpose(1, 2)
def resize_token_embeddings(self, *args, **kwargs):
-# TODO: now it's possible with prompt_embeddings
raise NotImplementedError(
"Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
" respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
@@ -2656,39 +2683,4 @@ class ParlerTTSForConditionalGeneration(PreTrainedModel):
outputs.sequences = output_values
return outputs
else:
return output_values
\ No newline at end of file
-def get_unconditional_inputs(self, num_samples=1):
-"""
-# TODO: Remove ?
-Helper function to get null inputs for unconditional generation, enabling the model to be used without the
-feature extractor or tokenizer.
-Args:
-num_samples (int, *optional*):
-Number of audio samples to unconditionally generate.
-max_new_tokens (int, *optional*):
-Number of tokens to generate for each sample. More tokens means longer audio samples, at the expense of
-longer inference (since more audio tokens need to be generated per sample).
-Example:
-```python
->>> from transformers import ParlerTTSForConditionalGeneration
->>> model = ParlerTTSForConditionalGeneration.from_pretrained("facebook/parler_tts-small")
->>> # get the unconditional (or 'null') inputs for the model
->>> unconditional_inputs = model.get_unconditional_inputs(num_samples=1)
->>> audio_samples = model.generate(**unconditional_inputs, max_new_tokens=256)
-```"""
-last_hidden_state = torch.zeros(
-(num_samples, 1, self.config.text_encoder.hidden_size), device=self.device, dtype=self.dtype
-)
-attention_mask = torch.zeros((num_samples, 1), device=self.device, dtype=torch.long)
-return ParlerTTSUnconditionalInput(
-encoder_outputs=(last_hidden_state,),
-attention_mask=attention_mask,
-guidance_scale=1.0,
-)