"tests/vscode:/vscode.git/clone" did not exist on "3b8f445f4b723e2dd36cffc0dd4dfd48d2258986"
Commit 018391a6 authored by Jared Casper

Merge branch 'untie_embeddings' into 'main'

Untie Embeddings

See merge request ADLR/megatron-lm!558
parents 4e891fe9 1e2a0405
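
For context: by default Megatron-LM's GPT model reuses the input word-embedding matrix as the output projection that produces the vocabulary logits (weight tying). This merge request adds a --untie-embeddings-and-output-weights flag so that a separate output matrix can be learned instead. A minimal PyTorch sketch of the difference follows; the toy class and names are illustrative only, not Megatron code.

import torch.nn as nn
import torch.nn.functional as F

class ToyLMHead(nn.Module):
    """Toy illustration of tied vs. untied output weights."""
    def __init__(self, vocab_size, hidden_size, untie_embeddings_and_output_weights=False):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # Untied: learn a separate [vocab, hidden] output projection.
        # Tied (default): reuse the embedding matrix for the logits.
        self.output_layer = (nn.Linear(hidden_size, vocab_size, bias=False)
                             if untie_embeddings_and_output_weights else None)

    def forward(self, hidden_states):
        weight = (self.output_layer.weight if self.output_layer is not None
                  else self.embedding.weight)
        return F.linear(hidden_states, weight)  # [..., vocab_size] logits
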
@@ -557,6 +557,8 @@ def _add_network_size_args(parser):
                        dest='bert_binary_head')
     group.add_argument('--num-experts', type=int, default=None,
                        help='Number of Experts in Switch Transformer (None means no Switch)')
+    group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
+                       help='Untie embeddings and output weights.'),
     return parser
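
The new option is a plain store_true flag, so embeddings stay tied unless it is passed explicitly. A standalone sketch of the parsing behaviour under that assumption (a bare parser, not Megatron's argument plumbing):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--untie-embeddings-and-output-weights', action='store_true',
                    help='Untie embeddings and output weights.')

assert parser.parse_args([]).untie_embeddings_and_output_weights is False   # default: tied
args = parser.parse_args(['--untie-embeddings-and-output-weights'])
assert args.untie_embeddings_and_output_weights is True                     # flag given: untied
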
@@ -132,6 +132,9 @@ class BertModel(MegatronModule):
         super(BertModel, self).__init__()
         args = get_args()
+        # TODO this option is not yet implemented in BERT
+        assert args.untie_embeddings_and_output_weights is False
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
@@ -50,13 +50,14 @@ class GPTModel(MegatronModule):
                  parallel_output=True,
                  pre_process=True,
                  post_process=True):
-        super(GPTModel, self).__init__()
         args = get_args()
+        super(GPTModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
         self.parallel_output = parallel_output
         self.pre_process = pre_process
         self.post_process = post_process
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
+        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
@@ -67,8 +68,9 @@ class GPTModel(MegatronModule):
                                                          args.num_layers),
             pre_process=self.pre_process,
             post_process=self.post_process)
-        self.initialize_word_embeddings(init_method_normal)
+        if not args.untie_embeddings_and_output_weights:
+            self.initialize_word_embeddings(init_method_normal)
     def set_input_tensor(self, input_tensor):
         """See megatron.model.transformer.set_input_tensor()"""
@@ -90,7 +92,7 @@ class GPTModel(MegatronModule):
         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
-                self.word_embeddings_weight(),
+                self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.word_embeddings_weight(),
                 self.parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
@@ -103,7 +105,7 @@ class GPTModel(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
                 prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
-        if self.post_process and not self.pre_process:
+        if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
             state_dict_[self._word_embeddings_for_head_key] \
                 = self.word_embeddings.state_dict(prefix=prefix,
                                                   keep_vars=keep_vars)
@@ -113,7 +115,7 @@ class GPTModel(MegatronModule):
         """Customized load."""
         # Load word_embeddings.
-        if self.post_process and not self.pre_process:
+        if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
             self.word_embeddings.load_state_dict(
                 state_dict[self._word_embeddings_for_head_key], strict=strict)
         if self._language_model_key in state_dict:
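
The checkpoint logic above follows from who owns the output weight. With tied weights, a pipeline-last stage that is not also the first stage keeps an extra copy of word_embeddings for the head, and that copy must be saved and loaded explicitly; with untied weights the output matrix is part of the language model's own state dict, so the extra copy is skipped. A small sketch of the condition, with hypothetical helper and argument names:

def needs_separate_head_embedding(pre_process, post_process, untie):
    """True when a stage must checkpoint its own copy of the tied embedding
    for the output head: last pipeline stage, not the first, weights tied."""
    return post_process and not pre_process and not untie

assert needs_separate_head_embedding(False, True, untie=False) is True
assert needs_separate_head_embedding(False, True, untie=True) is False   # untied: no extra copy
assert needs_separate_head_embedding(True, True, untie=False) is False   # first stage already owns it
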
@@ -336,8 +336,10 @@ class TransformerLanguageModel(MegatronModule):
                  add_pooler=False,
                  pre_process=True,
                  post_process=True):
-        super(TransformerLanguageModel, self).__init__()
         args = get_args()
+        # TODO: passing share_word_embeddings=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
+        if args.untie_embeddings_and_output_weights: assert not add_decoder
+        super(TransformerLanguageModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
         self.pre_process = pre_process
         self.post_process = post_process
@@ -350,6 +352,7 @@ class TransformerLanguageModel(MegatronModule):
         self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
         self.encoder_hidden_state = None
+        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
         # Embeddings.
         if self.pre_process:
@@ -434,6 +437,14 @@ class TransformerLanguageModel(MegatronModule):
                 self.pooler = Pooler(self.hidden_size, self.init_method)
                 self._pooler_key = 'pooler'
+            if self.untie_embeddings_and_output_weights:
+                self.output_layer = tensor_parallel.ColumnParallelLinear(
+                    args.hidden_size,
+                    args.padded_vocab_size,
+                    bias=False, # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
+                    init_method=self.init_method)
+                self._output_layer_key = 'output_layer'
     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""
@@ -566,6 +577,10 @@ class TransformerLanguageModel(MegatronModule):
                 state_dict_[self._pooler_key] \
                     = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
                                                                  keep_vars=keep_vars)
+            if self.untie_embeddings_and_output_weights:
+                state_dict_[self._output_layer_key] \
+                    = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars)
         if self.add_decoder:
             state_dict_[self._decoder_key] \
                 = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
@@ -621,6 +636,11 @@ class TransformerLanguageModel(MegatronModule):
                     'could not find data for pooler in the checkpoint'
                 self.pooler.load_state_dict(state_dict[self._pooler_key],
                                             strict=strict)
+            if self.untie_embeddings_and_output_weights:
+                assert 'output_layer' in state_dict, \
+                    'could not find data for output_layer in the checkpoint'
+                self.output_layer.load_state_dict(state_dict[self._output_layer_key],
+                                                  strict=strict)
         # Decoder.
         if self.add_decoder:
             assert 'decoder' in state_dict, \
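
Because the untied projection is checkpointed under its own 'output_layer' key, a checkpoint written with tied weights has no such entry, and resuming it with the flag turned on trips the assert above. A toy round trip showing the contract (stand-in module, not the Megatron checkpoint format):

import torch.nn as nn

output_layer = nn.Linear(16, 128, bias=False)               # stand-in for the untied projection
state_dict_ = {'output_layer': output_layer.state_dict()}   # written only when untied

assert 'output_layer' in state_dict_, \
    'could not find data for output_layer in the checkpoint'
output_layer.load_state_dict(state_dict_['output_layer'])
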
@@ -70,8 +70,7 @@ class MegatronModule(torch.nn.Module):
         # 3. In the training loop, before an all-reduce between the grads of
         # the two word_embeddings layers to ensure that every applied weight
         # update is the same on both stages.
-        if mpu.is_pipeline_last_stage() and \
-           not self.pre_process:
+        if mpu.is_pipeline_last_stage() and not self.pre_process:
             assert not mpu.is_pipeline_first_stage()
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # set word_embeddings weights to 0 here, then copy first
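
The comment above describes how tied embeddings are kept identical across pipeline stages: the last stage holds a zero-initialized copy of word_embeddings, receives the first stage's values once at startup, and the two gradients are then all-reduced every step so both copies apply the same update. With untied output weights none of this synchronization is needed. A hedged sketch of the per-step gradient sync, with made-up helper and group names rather than the actual Megatron training-loop code:

import torch.distributed as dist

def sync_tied_embedding_grads(word_embeddings, embedding_group, share_word_embeddings=True):
    """All-reduce the embedding gradient between the first and last pipeline
    stages so every applied weight update is the same for both copies."""
    if not share_word_embeddings:
        return  # untied output weights: the last stage trains its own matrix
    grad = word_embeddings.weight.grad
    if grad is not None:
        dist.all_reduce(grad, group=embedding_group)
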