Commit 18b26ec6 authored by Sandeep Subramanian, committed by Jimmy Zhang

Save word embeddings for head only if embeddings are not untied

parent f11b4c99
@@ -105,7 +105,7 @@ class GPTModel(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
                 prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
-        if self.post_process and not self.pre_process:
+        if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
             state_dict_[self._word_embeddings_for_head_key] \
                 = self.word_embeddings.state_dict(prefix=prefix,
                                                   keep_vars=keep_vars)
@@ -115,7 +115,7 @@ class GPTModel(MegatronModule):
         """Customized load."""
         # Load word_embeddings.
-        if self.post_process and not self.pre_process:
+        if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
             self.word_embeddings.load_state_dict(
                 state_dict[self._word_embeddings_for_head_key], strict=strict)
         if self._language_model_key in state_dict:
...
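For context, here is a minimal standalone sketch of why the extra "and not self.untie_embeddings_and_output_weights" guard matters. This is not the real Megatron-LM GPTModel; the ToyStage class, its constructor arguments, and the demo at the bottom are hypothetical. The idea it illustrates: when embeddings and output weights are tied, the last pipeline stage keeps a separate copy of the word embeddings for the output head and must checkpoint it; when they are untied, that copy does not exist, so saving or loading it must be skipped.

# Hypothetical, simplified illustration of tied vs. untied embedding checkpointing.
# ToyStage stands in for the last pipeline stage of a GPT-style model.
import torch.nn as nn


class ToyStage(nn.Module):
    def __init__(self, vocab_size=16, hidden=8,
                 pre_process=False, post_process=True,
                 untie_embeddings_and_output_weights=False):
        super().__init__()
        self.pre_process = pre_process
        self.post_process = post_process
        self.untie_embeddings_and_output_weights = untie_embeddings_and_output_weights
        self._word_embeddings_for_head_key = 'word_embeddings_for_head'
        if post_process and not pre_process and not untie_embeddings_and_output_weights:
            # Tied case: the last stage holds a copy of the input embeddings
            # so the output projection can reuse the same weights.
            self.word_embeddings = nn.Embedding(vocab_size, hidden)
        else:
            # Untied case: the stage has its own output layer instead; there is
            # no separate word_embeddings copy to checkpoint.
            self.output_layer = nn.Linear(hidden, vocab_size, bias=False)

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        state_dict_ = {}
        # Mirror of the patched condition: only save the extra embedding copy
        # when it actually exists (i.e. embeddings are tied).
        if (self.post_process and not self.pre_process
                and not self.untie_embeddings_and_output_weights):
            state_dict_[self._word_embeddings_for_head_key] = \
                self.word_embeddings.state_dict(prefix=prefix, keep_vars=keep_vars)
        return state_dict_


if __name__ == '__main__':
    tied = ToyStage(untie_embeddings_and_output_weights=False)
    untied = ToyStage(untie_embeddings_and_output_weights=True)
    print('tied keys:  ', list(tied.state_dict_for_save_checkpoint().keys()))
    print('untied keys:', list(untied.state_dict_for_save_checkpoint().keys()))

Running the sketch prints the head-embedding key only for the tied stage, mirroring what the patched condition guarantees for checkpoint save and load.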