Commit 813df4d2 authored by Yoach Lacombe

update modeling code with prompt concat

parent ca2cd16d
from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration
\ No newline at end of file
@@ -137,6 +137,8 @@ class StableSpeechConfig(PretrainedConfig):
documentation from [`PretrainedConfig`] for more information.
Args:
prompt_embed_dim (`int`, *optional*, defaults to 1024):
Dimensionality of the prompt embedding layer.
kwargs (*optional*):
Dictionary of keyword arguments. Notably:
@@ -187,7 +189,7 @@ class StableSpeechConfig(PretrainedConfig):
model_type = "stable_speech"
is_composition = True
def __init__(self, **kwargs):
def __init__(self, prompt_embed_dim=1024, **kwargs):
super().__init__(**kwargs)
if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
@@ -200,6 +202,7 @@ class StableSpeechConfig(PretrainedConfig):
decoder_config = kwargs.pop("decoder")
self.prompt_embed_dim = prompt_embed_dim
self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
self.decoder = StableSpeechDecoderConfig(**decoder_config)
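# a construction sketch (hypothetical sub-config dicts; each must carry a "model_type" key, popped in the elided lines above):
# config = StableSpeechConfig(
#     prompt_embed_dim=1024,
#     text_encoder={"model_type": "t5", ...},
#     audio_encoder={"model_type": "encodec", ...},
#     decoder={...},
# )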
......
@@ -689,6 +689,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
@@ -724,6 +726,22 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
# if prompt_hidden_states are provided, concatenate them to inputs_embeds and update the input shape
if prompt_hidden_states is not None:
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
input_shape = inputs_embeds.size()[:-1]
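# a shape sketch of the concatenation above (illustrative dimensions, not fixed by the model):
# prompt_hidden_states: (bsz, prompt_len, hidden) + inputs_embeds: (bsz, seq_len, hidden)
# -> inputs_embeds: (bsz, prompt_len + seq_len, hidden), input_shape: (bsz, prompt_len + seq_len)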
# TODO: verify whether a prompt attention mask is required
# As it is, the masked ids from the prompt still count in the position embeddings
if prompt_attention_mask is not None and attention_mask is not None:
attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
elif prompt_attention_mask is not None:
logger.warning_once(
"`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
)
attention_mask = torch.cat([prompt_attention_mask, torch.ones(input_shape, device=self.device, dtype=prompt_attention_mask.dtype)], dim=1)
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -862,6 +880,8 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
@@ -884,6 +904,8 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
prompt_hidden_states=prompt_hidden_states,
prompt_attention_mask=prompt_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
@@ -951,6 +973,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.LongTensor] = None,
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
@@ -962,7 +986,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
`labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
@@ -976,6 +1000,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
prompt_hidden_states=prompt_hidden_states,
prompt_attention_mask=prompt_attention_mask,
head_mask=head_mask,
cross_attn_head_mask=cross_attn_head_mask,
past_key_values=past_key_values,
@@ -992,7 +1018,17 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
loss = None
if labels is not None:
raise NotImplementedError("Training is not implemented for StableSpeech.")
# since the prompt hidden states were concatenated to the decoder hidden states, keep only the last
# `labels.shape[1]` positions, which correspond to the labels
logits = lm_logits[:, :, -labels.shape[1] :]
loss_fct = CrossEntropyLoss()
# per-codebook cross-entropy; labels set to -100 are ignored
# inputs: (bsz, vocab_size, seq_len, num_codebooks), targets: (bsz, seq_len, num_codebooks)
labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
loss = loss_fct(logits.transpose(1, 3), labels)
# (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
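# a minimal shape walk-through of the loss above (illustrative values, assuming bsz=2, 8 codebooks,
# prompt_len=50, label seq_len=200, vocab_size=2048):
# lm_logits: (2, 8, 250, 2048) -> logits after slicing the sequence dim: (2, 8, 200, 2048)
# logits.transpose(1, 3): (2, 2048, 200, 8) -- CrossEntropyLoss expects the class dim at position 1
# labels: (2, 200, 8) -- matches the (N, d1, d2) target layout, with -100 entries ignored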
@@ -1016,6 +1052,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
prompt_hidden_states=None,
prompt_attention_mask=None,
head_mask=None,
cross_attn_head_mask=None,
past_key_values=None,
@@ -1040,15 +1078,30 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
input_ids = input_ids.repeat((2, 1))
if attention_mask is not None:
attention_mask = attention_mask.repeat((2, 1))
if prompt_hidden_states is not None:
prompt_hidden_states = torch.concatenate(
[prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0
)
if prompt_attention_mask is not None:
prompt_attention_mask = torch.concatenate(
[prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0
)
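# note on the duplication above: as in MusicGen's classifier-free guidance, the batch is doubled so the
# first half runs the conditional branch and the second half an unconditional branch (zeroed prompt
# states / masks); the two sets of logits are presumably recombined later by a guidance logits processor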
if past_key_values is not None:
input_ids = input_ids[:, -1:]
# we only want to use the prompt signal in the first generation step, but we keep the attention mask
prompt_hidden_states = None
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"encoder_hidden_states": encoder_hidden_states,
"encoder_attention_mask": encoder_attention_mask,
"prompt_hidden_states": prompt_hidden_states,
"prompt_attention_mask": prompt_attention_mask,
"head_mask": head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"past_key_values": past_key_values,
@@ -1483,6 +1536,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
and self.decoder.config.cross_attention_hidden_size is None
):
self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
# prompt embeddings
self.embed_prompts = nn.Embedding(config.prompt_embed_dim, self.decoder.config.hidden_size)
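# note: nn.Embedding(num_embeddings, embedding_dim), so prompt_embed_dim acts here as the size of the
# prompt token vocabulary, and each prompt token id is embedded into the decoder's hidden size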
if self.text_encoder.get_output_embeddings() is not None:
raise ValueError(
@@ -1496,8 +1553,19 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
"following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
)
# tie text encoder, decoder weights if config set accordingly
self.tie_weights()
# Initialize projection and embedding layers and tie text encoder and decoder weights if set accordingly
self.post_init()
def _init_weights(self, module):
std = self.config.initializer_factor
if isinstance(module, (nn.Linear, nn.Conv1d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def tie_weights(self):
# tie text encoder & decoder if needed
@@ -1768,6 +1836,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
prompt_input_ids: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
prompt_attention_mask: Optional[torch.LongTensor] = None, # TODO: add to docstrings
prompt_hidden_states: Optional[torch.FloatTensor] = None, # TODO: add to docstrings
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
@@ -1844,6 +1915,11 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
if attention_mask is not None:
encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]
if prompt_hidden_states is None:
if prompt_input_ids is not None:
prompt_hidden_states = self.embed_prompts(prompt_input_ids)
# TODO: do we do something with prompt_attention_mask? e.g. multiply it with prompt_hidden_states?
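# shape sketch of the lookup above: prompt_input_ids (bsz, prompt_len) -> prompt_hidden_states
# (bsz, prompt_len, decoder_hidden_size), ready to be concatenated in front of the decoder inputs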
if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
decoder_input_ids = shift_tokens_right(
@@ -1876,29 +1952,23 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
attention_mask=decoder_attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask,
prompt_hidden_states=prompt_hidden_states,
prompt_attention_mask=prompt_attention_mask,
inputs_embeds=decoder_inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
use_cache=use_cache,
past_key_values=past_key_values,
return_dict=return_dict,
labels=labels,
**kwargs_decoder,
)
loss = None
if labels is not None:
logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
if loss is not None:
return (loss,) + decoder_outputs + encoder_outputs
else:
return decoder_outputs + encoder_outputs
return decoder_outputs + (encoder_hidden_states,)
return Seq2SeqLMOutput(
loss=loss,
loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
@@ -1917,6 +1987,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
head_mask=None,
decoder_attention_mask=None,
decoder_head_mask=None,
prompt_hidden_states=None,
prompt_attention_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
@@ -1940,6 +2012,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
decoder_input_ids = decoder_input_ids.repeat((2, 1))
if decoder_attention_mask is not None:
decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
if prompt_hidden_states is not None:
# TODO: we probably don't want to keep the guidance scale here? This is a different task than music generation
prompt_hidden_states = torch.concatenate([prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0)
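# same classifier-free guidance batch doubling as in StableSpeechForCausalLM.prepare_inputs_for_generation above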
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
@@ -1952,6 +2027,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
remove_prefix_length = decoder_input_ids.shape[1] - 1
decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
# we only want to use the prompt signal in the first generation step, but we keep the attention mask
prompt_hidden_states = None
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
@@ -2058,6 +2137,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state)
return model_kwargs
def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs):
model_kwargs["prompt_hidden_states"] = self.embed_prompts(prompt_input_ids)
return model_kwargs
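# this hook mirrors the encoder-preparation helpers around it: generate() calls it once (see below)
# to turn prompt_input_ids into prompt_hidden_states before the first decoding step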
def _prepare_audio_encoder_kwargs_for_generation(
self, input_values, model_kwargs, model_input_name: Optional[str] = None
@@ -2110,6 +2193,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
def resize_token_embeddings(self, *args, **kwargs):
# TODO: with the new prompt embeddings, resizing is now possible here
raise NotImplementedError(
"Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
" respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
@@ -2143,6 +2227,16 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
batch_size = value.shape[0]
break
return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
def freeze_encoders(self, freeze_text_encoder=True):
if freeze_text_encoder:
for param in self.text_encoder.parameters():
param.requires_grad = False
self.text_encoder._requires_grad = False
for param in self.audio_encoder.parameters():
param.requires_grad = False
self.audio_encoder._requires_grad = False
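# note: the audio encoder is always frozen by this helper; only freezing the text encoder is optional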
@torch.no_grad()
def generate(
@@ -2277,6 +2371,13 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
model_input_name,
guidance_scale=generation_config.guidance_scale,
)
if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
# `prompt_hidden_states` are created and added to `model_kwargs`
model_kwargs = self._prepare_prompt_kwargs_for_generation(
model_kwargs["prompt_input_ids"],
model_kwargs,
)
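# a minimal usage sketch (hypothetical tokenizer and variable names; prompt_input_ids must be ids that
# index into the prompt embedding table of size config.prompt_embed_dim):
# prompt_ids = prompt_tokenizer("A calm female voice.", return_tensors="pt").input_ids
# audio_codes = model.generate(input_ids=text_ids, prompt_input_ids=prompt_ids)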
if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
model_kwargs = self._prepare_audio_encoder_kwargs_for_generation(
@@ -2455,6 +2556,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
def get_unconditional_inputs(self, num_samples=1):
"""
# TODO: Remove ?
Helper function to get null inputs for unconditional generation, enabling the model to be used without the
feature extractor or tokenizer.
......