Commit 6b50a8c6 authored by Jimmy Zhang's avatar Jimmy Zhang
Browse files

assertion check for T5 and untied embeddings

parent 18b26ec6
...@@ -328,6 +328,7 @@ class TransformerLanguageModel(MegatronModule): ...@@ -328,6 +328,7 @@ class TransformerLanguageModel(MegatronModule):
post_process=True): post_process=True):
args = get_args() args = get_args()
# TODO: passing share_word_embeddings=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5. # TODO: passing share_word_embeddings=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
if args.untie_embeddings_and_output_weights: assert not add_decoder
super(TransformerLanguageModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights) super(TransformerLanguageModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
self.pre_process = pre_process self.pre_process = pre_process
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment