Commit 82c7ba57 authored by MaximumEntropy's avatar MaximumEntropy Committed by Jimmy Zhang
Browse files

Initial commit for untied embeddings


Signed-off-by: default avatarMaximumEntropy <sandeep.subramanian.1@umontreal.ca>
parent fdd34a82
......@@ -530,6 +530,8 @@ def _add_network_size_args(parser):
dest='bert_binary_head')
group.add_argument('--num-experts', type=int, default=None,
help='Number of Experts in Switch Transformer (None means no Switch)')
group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
help='Untie embeddings and output weights.'),
return parser
......
......@@ -53,6 +53,7 @@ def initialize_model_parallel(
pipeline_model_parallel_size: int = 1,
virtual_pipeline_model_parallel_size: Optional[int] = None,
pipeline_model_parallel_split_rank: Optional[int] = None,
untie_embeddings_and_output_weights: bool = False,
) -> None:
"""
Initialize model data parallel groups.
......@@ -93,6 +94,9 @@ def initialize_model_parallel(
pipeline_model_parallel_split_rank is 3, then ranks 0-2
will be the encoder and ranks 3-7 will be the decoder.
untie_embeddings_and_output_weights: whether to use separate embedding and output layer.
this affects the computation of embedding groups
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
the model pipeline. The present function will
......@@ -200,13 +204,19 @@ def initialize_model_parallel(
# Setup embedding group (to exchange gradients between
# first and last stages).
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
if untie_embeddings_and_output_weights:
embedding_ranks = [ranks[0]]
else:
embedding_ranks = [ranks[0], ranks[-1]]
position_embedding_ranks = [ranks[0]]
if pipeline_model_parallel_split_rank is not None:
if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank],
ranks[-1]]
if untie_embeddings_and_output_weights:
embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
else:
embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank],
ranks[-1]]
if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
position_embedding_ranks = [ranks[0],
ranks[pipeline_model_parallel_split_rank]]
......
......@@ -185,7 +185,8 @@ def _initialize_distributed():
mpu.initialize_model_parallel(args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
args.virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_split_rank)
args.pipeline_model_parallel_split_rank,
args.untie_embeddings_and_output_weights)
if args.rank == 0:
print(f'> initialized tensor model parallel with size '
f'{mpu.get_tensor_model_parallel_world_size()}')
......
......@@ -57,6 +57,7 @@ class GPTModel(MegatronModule):
self.pre_process = pre_process
self.post_process = post_process
self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
self.language_model, self._language_model_key = get_language_model(
num_tokentypes=num_tokentypes,
......@@ -90,7 +91,7 @@ class GPTModel(MegatronModule):
if self.post_process:
return post_language_model_processing(
lm_output, labels,
self.word_embeddings_weight(),
self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.word_embeddings_weight(),
self.parallel_output,
self.fp16_lm_cross_entropy)
else:
......
......@@ -326,8 +326,9 @@ class TransformerLanguageModel(MegatronModule):
add_pooler=False,
pre_process=True,
post_process=True):
super(TransformerLanguageModel, self).__init__()
args = get_args()
# TODO: passing share_word_embeddings=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
super(TransformerLanguageModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)
self.pre_process = pre_process
self.post_process = post_process
......@@ -340,6 +341,7 @@ class TransformerLanguageModel(MegatronModule):
self.decoder_attn_mask_type = decoder_attn_mask_type
self.add_pooler = add_pooler
self.encoder_hidden_state = None
self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights
# Embeddings.
if self.pre_process:
......@@ -408,6 +410,14 @@ class TransformerLanguageModel(MegatronModule):
self.pooler = Pooler(self.hidden_size, self.init_method)
self._pooler_key = 'pooler'
if self.untie_embeddings_and_output_weights:
self.output_layer = tensor_parallel.ColumnParallelLinear(
args.hidden_size,
args.padded_vocab_size,
bias=False, # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
init_method=output_layer_init_method)
self._output_layer_key = 'output_layer'
def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""
......@@ -529,6 +539,10 @@ class TransformerLanguageModel(MegatronModule):
state_dict_[self._pooler_key] \
= self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.untie_embeddings_and_output_weights:
state_dict_[self._output_layer_key] \
= self.output_layer.state_dict_for_save_checkpoint(prefix=prefix,
keep_vars=keep_vars)
if self.add_decoder:
state_dict_[self._decoder_key] \
= self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
......@@ -584,6 +598,11 @@ class TransformerLanguageModel(MegatronModule):
'could not find data for pooler in the checkpoint'
self.pooler.load_state_dict(state_dict[self._pooler_key],
strict=strict)
if self.untie_embeddings_and_output_weights:
assert 'output_layer' in state_dict, \
'could not find data for output_layer in the checkpoint'
self.output_layer.load_state_dict(state_dict[self._output_layer_key],
strict=strict)
# Decoder.
if self.add_decoder:
assert 'decoder' in state_dict, \
......
......@@ -71,7 +71,7 @@ class MegatronModule(torch.nn.Module):
# the two word_embeddings layers to ensure that every applied weight
# update is the same on both stages.
if mpu.is_pipeline_last_stage() and \
not self.pre_process:
not self.pre_process and not self.untie_embeddings_and_output_weights:
assert not mpu.is_pipeline_first_stage()
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
# set word_embeddings weights to 0 here, then copy first
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment