Commit 018391a6 authored by Jared Casper

Merge branch 'untie_embeddings' into 'main'

Untie Embeddings

See merge request ADLR/megatron-lm!558
parents 4e891fe9 1e2a0405
@@ -557,6 +557,8 @@ def _add_network_size_args(parser):
                        dest='bert_binary_head')
     group.add_argument('--num-experts', type=int, default=None,
                        help='Number of Experts in Switch Transformer (None means no Switch)')
+    group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
+                       help='Untie embeddings and output weights.'),
     return parser
...
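For context: the new option is a plain argparse store_true switch, so args.untie_embeddings_and_output_weights defaults to False (the previous tied behaviour) and only becomes True when the flag is passed explicitly. A minimal standalone sketch of that behaviour, using a throwaway parser rather than Megatron's real argument groups:

    import argparse

    # Hypothetical minimal reproduction; Megatron registers this flag on an
    # argument group inside _add_network_size_args() instead.
    parser = argparse.ArgumentParser()
    parser.add_argument('--untie-embeddings-and-output-weights', action='store_true',
                        help='Untie embeddings and output weights.')

    print(parser.parse_args([]).untie_embeddings_and_output_weights)   # False: weights stay tied
    args = parser.parse_args(['--untie-embeddings-and-output-weights'])
    print(args.untie_embeddings_and_output_weights)                    # True: separate output matrix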
@@ -132,6 +132,9 @@ class BertModel(MegatronModule):
         super(BertModel, self).__init__()
         args = get_args()
 
+        # TODO this option is not yet implemented in BERT
+        assert args.untie_embeddings_and_output_weights is False
+
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
...
@@ -50,13 +50,14 @@ class GPTModel(MegatronModule):
                  parallel_output=True,
                  pre_process=True,
                  post_process=True):
-        super(GPTModel, self).__init__()
         args = get_args()
+        super(GPTModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
 
         self.parallel_output = parallel_output
         self.pre_process = pre_process
         self.post_process = post_process
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
+        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
 
         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
@@ -68,6 +69,7 @@ class GPTModel(MegatronModule):
             pre_process=self.pre_process,
             post_process=self.post_process)
 
-        self.initialize_word_embeddings(init_method_normal)
+        if not args.untie_embeddings_and_output_weights:
+            self.initialize_word_embeddings(init_method_normal)
 
     def set_input_tensor(self, input_tensor):
@@ -90,7 +92,7 @@ class GPTModel(MegatronModule):
         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
-                self.word_embeddings_weight(),
+                self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.word_embeddings_weight(),
                 self.parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
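The only functional change in the forward pass is which weight matrix post_language_model_processing projects hidden states with: the shared input embedding when tied, or the new output layer's weight when untied. A simplified sketch of that choice in plain PyTorch, with no tensor or pipeline parallelism and purely illustrative names:

    import torch
    import torch.nn.functional as F

    def lm_logits(hidden_states, word_embedding, output_layer=None, untie=False):
        # Tied: reuse the input embedding matrix as the output projection.
        # Untied: project with the separately learned output_layer weight instead.
        weight = output_layer.weight if untie else word_embedding.weight
        return F.linear(hidden_states, weight)          # [batch, seq, vocab]

    vocab_size, hidden_size = 100, 16
    embedding = torch.nn.Embedding(vocab_size, hidden_size)
    output_layer = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    hidden = torch.randn(2, 8, hidden_size)

    tied_logits = lm_logits(hidden, embedding)                              # shares embedding.weight
    untied_logits = lm_logits(hidden, embedding, output_layer, untie=True)  # uses output_layer.weight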
@@ -103,7 +105,7 @@ class GPTModel(MegatronModule):
             = self.language_model.state_dict_for_save_checkpoint(
                 prefix=prefix, keep_vars=keep_vars)
         # Save word_embeddings.
-        if self.post_process and not self.pre_process:
+        if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
             state_dict_[self._word_embeddings_for_head_key] \
                 = self.word_embeddings.state_dict(prefix=prefix,
                                                   keep_vars=keep_vars)
@@ -113,7 +115,7 @@ class GPTModel(MegatronModule):
         """Customized load."""
 
         # Load word_embeddings.
-        if self.post_process and not self.pre_process:
+        if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
             self.word_embeddings.load_state_dict(
                 state_dict[self._word_embeddings_for_head_key], strict=strict)
         if self._language_model_key in state_dict:
...
@@ -336,8 +336,10 @@ class TransformerLanguageModel(MegatronModule):
                  add_pooler=False,
                  pre_process=True,
                  post_process=True):
-        super(TransformerLanguageModel, self).__init__()
         args = get_args()
+        # TODO: passing share_word_embeddings=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
+        if args.untie_embeddings_and_output_weights: assert not add_decoder
+        super(TransformerLanguageModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
 
         self.pre_process = pre_process
         self.post_process = post_process
@@ -350,6 +352,7 @@ class TransformerLanguageModel(MegatronModule):
         self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
         self.encoder_hidden_state = None
+        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
 
         # Embeddings.
         if self.pre_process:
@@ -434,6 +437,14 @@ class TransformerLanguageModel(MegatronModule):
                 self.pooler = Pooler(self.hidden_size, self.init_method)
                 self._pooler_key = 'pooler'
 
+            if self.untie_embeddings_and_output_weights:
+                self.output_layer = tensor_parallel.ColumnParallelLinear(
+                    args.hidden_size,
+                    args.padded_vocab_size,
+                    bias=False, # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
+                    init_method=self.init_method)
+                self._output_layer_key = 'output_layer'
+
     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""
@@ -566,6 +577,10 @@ class TransformerLanguageModel(MegatronModule):
                 state_dict_[self._pooler_key] \
                     = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
                                                                  keep_vars=keep_vars)
+            if self.untie_embeddings_and_output_weights:
+                state_dict_[self._output_layer_key] \
+                    = self.output_layer.state_dict(prefix=prefix, keep_vars=keep_vars)
+
         if self.add_decoder:
             state_dict_[self._decoder_key] \
                 = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
@@ -621,6 +636,11 @@ class TransformerLanguageModel(MegatronModule):
                     'could not find data for pooler in the checkpoint'
                 self.pooler.load_state_dict(state_dict[self._pooler_key],
                                             strict=strict)
+            if self.untie_embeddings_and_output_weights:
+                assert 'output_layer' in state_dict, \
+                    'could not find data for output_layer in the checkpoint'
+                self.output_layer.load_state_dict(state_dict[self._output_layer_key],
+                                                  strict=strict)
         # Decoder.
         if self.add_decoder:
             assert 'decoder' in state_dict, \
...
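A practical consequence of the state-dict changes above: checkpoints saved with untied weights carry an extra 'output_layer' entry, and loading with --untie-embeddings-and-output-weights asserts that the entry is present, so a tied checkpoint cannot simply be resumed as an untied model without conversion. A small sketch of a compatibility check on a language-model state dict, using the key names from this diff:

    import torch

    def is_untied_lm_state_dict(lm_state_dict):
        # state_dict_for_save_checkpoint() above stores the untied projection
        # under the 'output_layer' key; a tied checkpoint has no such entry.
        return 'output_layer' in lm_state_dict

    # Usage sketch (checkpoint file name and nesting are illustrative, not guaranteed):
    # ckpt = torch.load('model_optim_rng.pt', map_location='cpu')
    # print(is_untied_lm_state_dict(ckpt['model']['language_model']))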
@@ -70,8 +70,7 @@ class MegatronModule(torch.nn.Module):
         # 3. In the training loop, before an all-reduce between the grads of
         #    the two word_embeddings layers to ensure that every applied weight
         #    update is the same on both stages.
-        if mpu.is_pipeline_last_stage() and \
-           not self.pre_process:
+        if mpu.is_pipeline_last_stage() and not self.pre_process:
             assert not mpu.is_pipeline_first_stage()
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # set word_embeddings weights to 0 here, then copy first
...
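The comment in this last hunk explains why the tied path needs extra machinery at all: with pipeline parallelism, the first stage (input embedding) and last stage (output projection) each hold a replica of the tied weight, and their gradients are summed before the optimizer step so both replicas stay identical. Untying removes that requirement, which is why GPTModel skips initialize_word_embeddings when the flag is set. A schematic sketch of the synchronization that the tied case relies on, using plain torch.distributed (the function name and the embedding_group argument are illustrative; the real synchronization happens inside Megatron's training loop):

    import torch.distributed as dist

    def allreduce_tied_embedding_grads(word_embeddings, embedding_group):
        # Sum the gradient of the shared weight across the first and last pipeline
        # stages so that both replicas apply exactly the same optimizer update.
        if word_embeddings.weight.grad is not None:
            dist.all_reduce(word_embeddings.weight.grad,
                            op=dist.ReduceOp.SUM,
                            group=embedding_group)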