"test/old-api/spectests.cpp" did not exist on "fa0af88dfeef3c6ed06296b34989d548032b13f0"
Commit 813df4d2 authored by Yoach Lacombe

update modeling code with prompt concat

parent ca2cd16d
from .configuration_stable_speech import StableSpeechConfig, StableSpeechDecoderConfig
from .modeling_stable_speech import StableSpeechForCausalLM, StableSpeechForConditionalGeneration
\ No newline at end of file
@@ -137,6 +137,8 @@ class StableSpeechConfig(PretrainedConfig):
     documentation from [`PretrainedConfig`] for more information.
 
     Args:
+        prompt_embed_dim (`int`, *optional*, defaults to 1024):
+            Dimensionality of the prompt embedding layer.
         kwargs (*optional*):
             Dictionary of keyword arguments. Notably:
@@ -187,7 +189,7 @@ class StableSpeechConfig(PretrainedConfig):
     model_type = "stable_speech"
     is_composition = True
 
-    def __init__(self, **kwargs):
+    def __init__(self, prompt_embed_dim=1024, **kwargs):
         super().__init__(**kwargs)
         if "text_encoder" not in kwargs or "audio_encoder" not in kwargs or "decoder" not in kwargs:
             raise ValueError("Config has to be initialized with text_encoder, audio_encoder and decoder config")
@@ -200,6 +202,7 @@ class StableSpeechConfig(PretrainedConfig):
         decoder_config = kwargs.pop("decoder")
 
+        self.prompt_embed_dim = prompt_embed_dim
         self.text_encoder = AutoConfig.for_model(text_encoder_model_type, **text_encoder_config)
         self.audio_encoder = AutoConfig.for_model(audio_encoder_model_type, **audio_encoder_config)
         self.decoder = StableSpeechDecoderConfig(**decoder_config)
...
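Reviewer note: the two configuration hunks above make `prompt_embed_dim` a first-class config argument. A minimal sketch of how the updated config could be instantiated; the `t5`/`encodec` sub-model types, the package import path, and the decoder kwargs are illustrative assumptions, not part of this commit:

```python
from transformers import AutoConfig

# StableSpeechConfig is exposed by this package's __init__ (see the hunk above);
# "stable_speech" is an assumed package name.
from stable_speech import StableSpeechConfig

text_encoder = AutoConfig.for_model("t5").to_dict()       # includes "model_type", popped by __init__
audio_encoder = AutoConfig.for_model("encodec").to_dict()
decoder = {"hidden_size": 1024}  # hypothetical StableSpeechDecoderConfig kwargs

config = StableSpeechConfig(
    prompt_embed_dim=1024,  # the new argument added by this commit
    text_encoder=text_encoder,
    audio_encoder=audio_encoder,
    decoder=decoder,
)
assert config.prompt_embed_dim == 1024
```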
@@ -689,6 +689,8 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
@@ -724,6 +726,22 @@ class StableSpeechDecoder(StableSpeechPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
 
+        # if prompt_hidden_states are provided, prepend them to inputs_embeds and update the input shape
+        if prompt_hidden_states is not None:
+            inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
+            input_shape = inputs_embeds.size()[:-1]
+
+            # TODO: verify whether a prompt attention mask is required
+            # As it is, the masked ids from the prompt will still count in the position embeddings
+            if prompt_attention_mask is not None and attention_mask is not None:
+                attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
+            elif prompt_attention_mask is not None:
+                logger.warning_once(
+                    "`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour."
+                )
+                # input_shape already includes the prompt length, so only fill in the remaining positions
+                full_mask = torch.ones(input_shape[0], input_shape[1] - prompt_attention_mask.shape[1], device=self.device, dtype=prompt_attention_mask.dtype)
+                attention_mask = torch.cat([prompt_attention_mask, full_mask], dim=1)
+
         attention_mask = _prepare_4d_causal_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
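A standalone sketch of the shape bookkeeping the new block performs; all dimensions are illustrative:

```python
import torch

bsz, prompt_len, seq_len, hidden = 2, 3, 5, 8
prompt_hidden_states = torch.randn(bsz, prompt_len, hidden)
inputs_embeds = torch.randn(bsz, seq_len, hidden)

# prepend the prompt along the time axis and refresh the shape used for the causal mask
inputs_embeds = torch.cat([prompt_hidden_states, inputs_embeds], dim=1)
input_shape = inputs_embeds.size()[:-1]
assert input_shape == (bsz, prompt_len + seq_len)

# the attention mask must grow by the same prompt length
prompt_attention_mask = torch.ones(bsz, prompt_len, dtype=torch.long)
attention_mask = torch.ones(bsz, seq_len, dtype=torch.long)
attention_mask = torch.cat([prompt_attention_mask, attention_mask], dim=1)
assert attention_mask.shape == (bsz, prompt_len + seq_len)
```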
@@ -862,6 +880,8 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
@@ -884,6 +904,8 @@ class StableSpeechModel(StableSpeechPreTrainedModel):
             attention_mask=attention_mask,
             encoder_attention_mask=encoder_attention_mask,
             encoder_hidden_states=encoder_hidden_states,
+            prompt_hidden_states=prompt_hidden_states,
+            prompt_attention_mask=prompt_attention_mask,
             head_mask=head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
@@ -951,6 +973,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.FloatTensor] = None,
         encoder_attention_mask: Optional[torch.LongTensor] = None,
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
         head_mask: Optional[torch.Tensor] = None,
         cross_attn_head_mask: Optional[torch.Tensor] = None,
         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
@@ -962,7 +986,7 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         return_dict: Optional[bool] = None,
     ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
-        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length, num_codebooks)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
             are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
@@ -976,6 +1000,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
             attention_mask=attention_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=encoder_attention_mask,
+            prompt_hidden_states=prompt_hidden_states,
+            prompt_attention_mask=prompt_attention_mask,
             head_mask=head_mask,
             cross_attn_head_mask=cross_attn_head_mask,
             past_key_values=past_key_values,
@@ -992,7 +1018,17 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
 
         loss = None
         if labels is not None:
-            raise NotImplementedError("Training is not implemented for StableSpeech.")
+            # since the prompt hidden states were concatenated to the decoder input embeddings,
+            # keep only the last `labels.shape[1]` logits, i.e. the ones that line up with the labels
+            logits = lm_logits[:, :, -labels.shape[1]:]
+
+            # per-codebook cross-entropy; `-100` labels are ignored
+            # shapes: (bsz, vocab_size, seq_len, num_codebooks) vs. (bsz, seq_len, num_codebooks)
+            loss_fct = CrossEntropyLoss()
+            labels = labels.masked_fill(labels == self.config.pad_token_id, -100)
+            loss = loss_fct(logits.transpose(1, 3), labels)
 
         # (bsz, num_codebooks, seq_len, vocab_size) -> (bsz * num_codebooks, seq_len, vocab_size)
         lm_logits = lm_logits.reshape(-1, *lm_logits.shape[2:])
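The loss hunk assumes the logits are exactly the prompt length longer than the labels along the time axis. A self-contained sketch of the slicing and the per-codebook cross-entropy with dummy tensors:

```python
import torch
from torch.nn import CrossEntropyLoss

bsz, num_codebooks, seq_len, vocab_size, prompt_len = 2, 4, 6, 10, 3

# logits cover prompt + label positions; labels only cover the label positions
lm_logits = torch.randn(bsz, num_codebooks, prompt_len + seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (bsz, seq_len, num_codebooks))

# keep the trailing positions that line up with the labels
logits = lm_logits[:, :, -labels.shape[1]:]

# CrossEntropyLoss expects (N, C, d1, ...): transpose to (bsz, vocab_size, seq_len, num_codebooks)
loss = CrossEntropyLoss()(logits.transpose(1, 3), labels)
print(loss)  # scalar tensor
```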
@@ -1016,6 +1052,8 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
         attention_mask=None,
         encoder_hidden_states=None,
         encoder_attention_mask=None,
+        prompt_hidden_states=None,
+        prompt_attention_mask=None,
         head_mask=None,
         cross_attn_head_mask=None,
         past_key_values=None,
@@ -1040,15 +1078,30 @@ class StableSpeechForCausalLM(StableSpeechPreTrainedModel):
             input_ids = input_ids.repeat((2, 1))
             if attention_mask is not None:
                 attention_mask = attention_mask.repeat((2, 1))
 
+            if prompt_hidden_states is not None:
+                prompt_hidden_states = torch.concatenate(
+                    [prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0
+                )
+
+            if prompt_attention_mask is not None:
+                prompt_attention_mask = torch.concatenate(
+                    [prompt_attention_mask, torch.zeros_like(prompt_attention_mask)], dim=0
+                )
 
         if past_key_values is not None:
             input_ids = input_ids[:, -1:]
 
+            # we only want to use the prompt signal in the first generation step, but we keep the attention mask
+            prompt_hidden_states = None
 
         return {
             "input_ids": input_ids,
             "attention_mask": attention_mask,
             "encoder_hidden_states": encoder_hidden_states,
             "encoder_attention_mask": encoder_attention_mask,
+            "prompt_hidden_states": prompt_hidden_states,
+            "prompt_attention_mask": prompt_attention_mask,
             "head_mask": head_mask,
             "cross_attn_head_mask": cross_attn_head_mask,
             "past_key_values": past_key_values,
@@ -1483,6 +1536,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             and self.decoder.config.cross_attention_hidden_size is None
         ):
             self.enc_to_dec_proj = nn.Linear(self.text_encoder.config.hidden_size, self.decoder.config.hidden_size)
 
+        # prompt embeddings
+        self.embed_prompts = nn.Embedding(config.prompt_embed_dim, self.decoder.config.hidden_size)
+
         if self.text_encoder.get_output_embeddings() is not None:
             raise ValueError(
@@ -1496,8 +1553,19 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
                 "following discussion on GitHub: https://github.com/huggingface/transformers/issues/23350"
             )
 
-        # tie text encoder, decoder weights if config set accordingly
-        self.tie_weights()
+        # Initialize projection and embedding layers and tie text encoder and decoder weights if set accordingly
+        self.post_init()
+
+    def _init_weights(self, module):
+        std = self.config.initializer_factor
+        if isinstance(module, (nn.Linear, nn.Conv1d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
 
     def tie_weights(self):
         # tie text encoder & decoder if needed
@@ -1768,6 +1836,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         past_key_values: Tuple[Tuple[torch.FloatTensor]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        prompt_input_ids: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
+        prompt_attention_mask: Optional[torch.LongTensor] = None,  # TODO: add to docstrings
+        prompt_hidden_states: Optional[torch.FloatTensor] = None,  # TODO: add to docstrings
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -1844,6 +1915,11 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             if attention_mask is not None:
                 encoder_hidden_states = encoder_hidden_states * attention_mask[..., None]
 
+        if prompt_hidden_states is None:
+            if prompt_input_ids is not None:
+                prompt_hidden_states = self.embed_prompts(prompt_input_ids)
+            # TODO: should we do something with prompt_attention_mask, e.g. multiply it with prompt_hidden_states?
+
         if (labels is not None) and (decoder_input_ids is None and decoder_inputs_embeds is None):
             decoder_input_ids = shift_tokens_right(
@@ -1876,29 +1952,23 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             attention_mask=decoder_attention_mask,
             encoder_hidden_states=encoder_hidden_states,
             encoder_attention_mask=attention_mask,
+            prompt_hidden_states=prompt_hidden_states,
+            prompt_attention_mask=prompt_attention_mask,
             inputs_embeds=decoder_inputs_embeds,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             use_cache=use_cache,
             past_key_values=past_key_values,
             return_dict=return_dict,
+            labels=labels,
             **kwargs_decoder,
         )
 
-        loss = None
-        if labels is not None:
-            logits = decoder_outputs.logits if return_dict else decoder_outputs[0]
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
-
         if not return_dict:
-            if loss is not None:
-                return (loss,) + decoder_outputs + encoder_outputs
-            else:
-                return decoder_outputs + encoder_outputs
+            return decoder_outputs + (encoder_hidden_states,)
 
         return Seq2SeqLMOutput(
-            loss=loss,
+            loss=decoder_outputs.loss,
             logits=decoder_outputs.logits,
             past_key_values=decoder_outputs.past_key_values,
             decoder_hidden_states=decoder_outputs.hidden_states,
@@ -1917,6 +1987,8 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         head_mask=None,
         decoder_attention_mask=None,
         decoder_head_mask=None,
+        prompt_hidden_states=None,
+        prompt_attention_mask=None,
         cross_attn_head_mask=None,
         use_cache=None,
         encoder_outputs=None,
@@ -1940,6 +2012,9 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             decoder_input_ids = decoder_input_ids.repeat((2, 1))
             if decoder_attention_mask is not None:
                 decoder_attention_mask = decoder_attention_mask.repeat((2, 1))
+            if prompt_hidden_states is not None:
+                # TODO: we probably don't want to keep the guidance scale here; this is a different task than music generation
+                prompt_hidden_states = torch.concatenate([prompt_hidden_states, torch.zeros_like(prompt_hidden_states)], dim=0)
 
         if past_key_values is not None:
             past_length = past_key_values[0][0].shape[2]
@@ -1952,6 +2027,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
                 remove_prefix_length = decoder_input_ids.shape[1] - 1
 
             decoder_input_ids = decoder_input_ids[:, remove_prefix_length:]
+
+            # we only want to use the prompt signal in the first generation step, but we keep the attention mask
+            prompt_hidden_states = None
 
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
@@ -2058,6 +2137,10 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
             model_kwargs["encoder_outputs"] = BaseModelOutput(last_hidden_state=last_hidden_state)
 
         return model_kwargs
 
+    def _prepare_prompt_kwargs_for_generation(self, prompt_input_ids, model_kwargs):
+        model_kwargs["prompt_hidden_states"] = self.embed_prompts(prompt_input_ids)
+        return model_kwargs
+
     def _prepare_audio_encoder_kwargs_for_generation(
         self, input_values, model_kwargs, model_input_name: Optional[str] = None
@@ -2110,6 +2193,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
         return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id)
 
     def resize_token_embeddings(self, *args, **kwargs):
+        # TODO: revisit, this may now be possible via the prompt embeddings
         raise NotImplementedError(
             "Resizing the embedding layers via the EncoderDecoderModel directly is not supported. Please use the"
             " respective methods of the wrapped objects (model.encoder.resize_token_embeddings(...) or"
@@ -2143,6 +2227,16 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
                 batch_size = value.shape[0]
                 break
         return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * bos_token_id
 
+    def freeze_encoders(self, freeze_text_encoder=True):
+        if freeze_text_encoder:
+            for param in self.text_encoder.parameters():
+                param.requires_grad = False
+            self.text_encoder._requires_grad = False
+
+        for param in self.audio_encoder.parameters():
+            param.requires_grad = False
+        self.audio_encoder._requires_grad = False
+
     @torch.no_grad()
     def generate(
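`freeze_encoders` is the standard `requires_grad` flip: the audio encoder is always frozen, the text encoder only when the flag is set. The pattern on a toy module:

```python
import torch.nn as nn

encoder = nn.Linear(4, 4)  # stand-in for self.audio_encoder
for param in encoder.parameters():
    param.requires_grad = False  # the optimizer will now skip these weights

assert all(not p.requires_grad for p in encoder.parameters())
```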
@@ -2277,6 +2371,13 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
                 model_input_name,
                 guidance_scale=generation_config.guidance_scale,
             )
 
+        if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
+            # `prompt_hidden_states` are created and added to `model_kwargs`
+            model_kwargs = self._prepare_prompt_kwargs_for_generation(
+                model_kwargs["prompt_input_ids"],
+                model_kwargs,
+            )
+
         if "decoder_input_ids" not in model_kwargs and "input_values" in model_kwargs:
             model_kwargs = self._prepare_audio_encoder_kwargs_for_generation(
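With this hook in place, passing `prompt_input_ids` through `generate`'s model kwargs should materialize `prompt_hidden_states` before decoding starts. A self-contained simulation of the gating logic; the names mirror the diff, but the embedding table here is a dummy stand-in:

```python
import torch
import torch.nn as nn

embed_prompts = nn.Embedding(1024, 16)  # dummy stand-in for self.embed_prompts

def _prepare_prompt_kwargs_for_generation(prompt_input_ids, model_kwargs):
    model_kwargs["prompt_hidden_states"] = embed_prompts(prompt_input_ids)
    return model_kwargs

model_kwargs = {"prompt_input_ids": torch.randint(0, 1024, (1, 12))}
if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
    model_kwargs = _prepare_prompt_kwargs_for_generation(model_kwargs["prompt_input_ids"], model_kwargs)

assert model_kwargs["prompt_hidden_states"].shape == (1, 12, 16)
```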
@@ -2455,6 +2556,7 @@ class StableSpeechForConditionalGeneration(PreTrainedModel):
     def get_unconditional_inputs(self, num_samples=1):
         """
+        # TODO: remove?
         Helper function to get null inputs for unconditional generation, enabling the model to be used without the
         feature extractor or tokenizer.
...