"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "adb2503ea3f8b237b3423f69b5371d7c3c00c1f1"
Unverified commit b9f12bed authored by smelm, committed by GitHub

Only call get_output_embeddings when tie_word_embeddings is set (#16667)



This avoids an unnecessary call and prevents problems during the
initialization of class hierarchies.
Co-authored-by: Samuel Melm <samuel.melm@stud.uni-heidelberg.de>
parent 924484ee
@@ -891,9 +891,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
         weights instead.
         """
-        output_embeddings = self.get_output_embeddings()
-        if output_embeddings is not None and getattr(self.config, "tie_word_embeddings", True):
-            self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
+        if getattr(self.config, "tie_word_embeddings", True):
+            output_embeddings = self.get_output_embeddings()
+            if output_embeddings is not None:
+                self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())
 
         if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False):
             if hasattr(self, self.base_model_prefix):
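Why the reordering matters: base-class constructors in this hierarchy call tie_weights() before a subclass has finished its own __init__. With the old order, get_output_embeddings() ran unconditionally and could touch attributes (such as an lm_head) that had not been assigned yet. Below is a minimal, self-contained sketch of that failure mode; Base, Child, embed, and lm_head are hypothetical stand-ins, not the actual transformers classes, and only the body of tie_weights mirrors the patched logic.

import torch.nn as nn
from types import SimpleNamespace


class Base(nn.Module):
    """Stand-in for a base model class that ties weights during construction."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embed = nn.Embedding(8, 8)
        self.tie_weights()  # runs before any subclass __init__ body finishes

    def get_input_embeddings(self):
        return self.embed

    def get_output_embeddings(self):
        return None  # overridden by subclasses that have an output projection

    def tie_weights(self):
        # Patched order: consult the config flag first, so that
        # get_output_embeddings() is never called when tying is disabled.
        if getattr(self.config, "tie_word_embeddings", True):
            output_embeddings = self.get_output_embeddings()
            if output_embeddings is not None:
                # placeholder for _tie_or_clone_weights(...)
                output_embeddings.weight = self.get_input_embeddings().weight


class Child(Base):
    def __init__(self, config):
        super().__init__(config)        # triggers Base.tie_weights() ...
        self.lm_head = nn.Linear(8, 8)  # ... before lm_head is assigned

    def get_output_embeddings(self):
        # With the pre-patch order this ran inside Base.__init__ and raised
        # AttributeError, because self.lm_head did not exist yet.
        return self.lm_head


# Constructs cleanly after the patch: tying is disabled, so lm_head is
# never touched during construction. With the old order, this raised
# AttributeError from within Base.__init__.
model = Child(SimpleNamespace(tie_word_embeddings=False))

Note that the patch does not change behavior for models that do tie embeddings; it only removes the spurious call for configurations where tie_word_embeddings is False, which is where the initialization-order problem surfaced.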