Unverified Commit ee4fb326 authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

Fix M4T weights tying (#27395)

fix seamless m4t weights tying
parent e107ae36
......@@ -2785,6 +2785,12 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel):
self.text_decoder.embed_tokens = value
self.shared = value
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING)
def forward(
self,
......@@ -3077,6 +3083,11 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
def set_input_embeddings(self, value):
self.text_decoder.embed_tokens = value
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@add_start_docstrings_to_model_forward(M4T_SPEECH_INPUTS_DOCSTRING)
def forward(
self,
......@@ -3384,6 +3395,12 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel):
self.text_decoder.embed_tokens = value
self.shared = value
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@add_start_docstrings_to_model_forward(M4T_TEXT_INPUTS_DOCSTRING)
def forward(
self,
......@@ -3740,6 +3757,11 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel):
def set_input_embeddings(self, value):
self.text_decoder.embed_tokens = value
def _tie_weights(self):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@add_start_docstrings_to_model_forward(M4T_SPEECH_INPUTS_DOCSTRING)
def forward(
self,
......@@ -4135,6 +4157,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel):
if self.config.tie_word_embeddings:
self._tie_or_clone_weights(self.text_encoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.text_decoder.embed_tokens, self.shared)
self._tie_or_clone_weights(self.lm_head, self.shared)
@add_start_docstrings_to_model_forward(M4T_MODEL_INPUTS_DOCSTRING)
def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment