Unverified Commit b9f12bed authored by smelm, committed by GitHub

Only call get_output_embeddings when tie_word_embeddings is set (#16667)

This avoids an unnecessary call and prevents problems during the
initialization of class hierarchies, where get_output_embeddings can
run before a subclass has finished setting up its modules.
Co-authored-by: Samuel Melm <samuel.melm@stud.uni-heidelberg.de>
parent 924484ee
@@ -891,8 +891,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
         weights instead.
         """
-        output_embeddings = self.get_output_embeddings()
-        if output_embeddings is not None and getattr(self.config, "tie_word_embeddings", True):
-            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
+        if getattr(self.config, "tie_word_embeddings", True):
+            output_embeddings = self.get_output_embeddings()
+            if output_embeddings is not None:
+                self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
         if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):