Commit f11b4c99 authored by Jimmy Zhang's avatar Jimmy Zhang
Browse files

disable embedding addreduce if untie_embeddings_and_output_weights

parent a3fbac58
......@@ -349,7 +349,6 @@ def validate_args(args, defaults={}):
"Using async gradient all reduce requires setting the environment "
"variable CUDA_DEVICE_MAX_CONNECTIONS to 1")
# Load retro args.
if args.retro_workdir:
retro_args_path = get_retro_args_path(args.retro_workdir)
......@@ -368,7 +367,6 @@ def validate_args(args, defaults={}):
if retro_args and args != retro_args:
_print_args("retro arguments", types.SimpleNamespace(**{k:v for k,v in vars(retro_args).items() if k.startswith("retro")}, rank=args.rank))
return args
......
......@@ -50,8 +50,8 @@ class GPTModel(MegatronModule):
parallel_output=True,
pre_process=True,
post_process=True):
super(GPTModel, self).__init__()
args = get_args()
super(GPTModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
self.parallel_output = parallel_output
self.pre_process = pre_process
......@@ -68,8 +68,9 @@ class GPTModel(MegatronModule):
args.num_layers),
pre_process=self.pre_process,
post_process=self.post_process)
self.initialize_word_embeddings(init_method_normal)
if not args.untie_embeddings_and_output_weights:
self.initialize_word_embeddings(init_method_normal)
def set_input_tensor(self, input_tensor):
"""See megatron.model.transformer.set_input_tensor()"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment