"vscode:/vscode.git/clone" did not exist on "9430be7612f6f2c1bd363cf1c18b5e6d81b95fe0"
Commit 82c7ba57 authored by MaximumEntropy, committed by Jimmy Zhang

Initial commit for untied embeddings


Signed-off-by: MaximumEntropy <sandeep.subramanian.1@umontreal.ca>
parent fdd34a82
@@ -530,6 +530,8 @@ def _add_network_size_args(parser):
                        dest='bert_binary_head')
     group.add_argument('--num-experts', type=int, default=None,
                        help='Number of Experts in Switch Transformer (None means no Switch)')
+    group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
+                       help='Untie embeddings and output weights.'),
     return parser
...
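The flag defaults to off, so existing configs keep tied weights. A minimal standalone sketch of the store_true behaviour (plain argparse, not Megatron's actual parser wiring):

```python
import argparse

# Minimal sketch: the new flag in isolation. Megatron registers it inside
# _add_network_size_args; this only illustrates the store_true default.
parser = argparse.ArgumentParser()
parser.add_argument('--untie-embeddings-and-output-weights', action='store_true',
                    help='Untie embeddings and output weights.')

args = parser.parse_args([])
assert args.untie_embeddings_and_output_weights is False   # default: weights stay tied

args = parser.parse_args(['--untie-embeddings-and-output-weights'])
assert args.untie_embeddings_and_output_weights is True    # flag given: untie
```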
@@ -53,6 +53,7 @@ def initialize_model_parallel(
     pipeline_model_parallel_size: int = 1,
     virtual_pipeline_model_parallel_size: Optional[int] = None,
     pipeline_model_parallel_split_rank: Optional[int] = None,
+    untie_embeddings_and_output_weights: bool = False,
 ) -> None:
     """
     Initialize model data parallel groups.
@@ -93,6 +94,9 @@ def initialize_model_parallel(
         pipeline_model_parallel_split_rank is 3, then ranks 0-2
         will be the encoder and ranks 3-7 will be the decoder.
+    untie_embeddings_and_output_weights: whether to use separate embedding and
+        output layers. This affects how the embedding groups are computed.

     Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
     use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
     the model pipeline. The present function will
@@ -200,10 +204,16 @@ def initialize_model_parallel(
         # Setup embedding group (to exchange gradients between
         # first and last stages).
         if len(ranks) > 1:
+            if untie_embeddings_and_output_weights:
+                embedding_ranks = [ranks[0]]
+            else:
+                embedding_ranks = [ranks[0], ranks[-1]]
             position_embedding_ranks = [ranks[0]]
             if pipeline_model_parallel_split_rank is not None:
                 if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
+                    if untie_embeddings_and_output_weights:
+                        embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
+                    else:
+                        embedding_ranks = [ranks[0],
+                                           ranks[pipeline_model_parallel_split_rank],
+                                           ranks[-1]]
...
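The branch added above changes which ranks join the embedding-gradient group. A standalone sketch of that selection, run on the 16-GPU example from the docstring (tensor parallel = 2, pipeline parallel = 4, so one pipeline group is g0-g4-g8-g12); the helper name and rank numbers are illustrative only:

```python
def select_embedding_ranks(ranks, untie_embeddings_and_output_weights,
                           pipeline_model_parallel_split_rank=None):
    """Mirror of the embedding-group selection above for one pipeline group."""
    if untie_embeddings_and_output_weights:
        # The output layer owns its own weight, so only the first stage's
        # embedding takes part in the gradient exchange.
        embedding_ranks = [ranks[0]]
    else:
        # Tied weights: first and last stages must all-reduce embedding grads.
        embedding_ranks = [ranks[0], ranks[-1]]
    if pipeline_model_parallel_split_rank is not None:
        split = ranks[pipeline_model_parallel_split_rank]
        if split not in embedding_ranks:
            if untie_embeddings_and_output_weights:
                embedding_ranks = [ranks[0], split]
            else:
                embedding_ranks = [ranks[0], split, ranks[-1]]
    return embedding_ranks

pipeline_group = [0, 4, 8, 12]                            # g0 -> g4 -> g8 -> g12
print(select_embedding_ranks(pipeline_group, False))      # [0, 12]     tied
print(select_embedding_ranks(pipeline_group, True))       # [0]         untied
print(select_embedding_ranks(pipeline_group, False, 2))   # [0, 8, 12]  tied, encoder/decoder split
print(select_embedding_ranks(pipeline_group, True, 2))    # [0, 8]      untied, encoder/decoder split
```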
@@ -185,7 +185,8 @@ def _initialize_distributed():
         mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                       args.pipeline_model_parallel_size,
                                       args.virtual_pipeline_model_parallel_size,
-                                      args.pipeline_model_parallel_split_rank)
+                                      args.pipeline_model_parallel_split_rank,
+                                      args.untie_embeddings_and_output_weights)
         if args.rank == 0:
             print(f'> initialized tensor model parallel with size '
                   f'{mpu.get_tensor_model_parallel_world_size()}')
...
@@ -57,6 +57,7 @@ class GPTModel(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
+        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

         self.language_model, self._language_model_key = get_language_model(
             num_tokentypes=num_tokentypes,
@@ -90,7 +91,7 @@ class GPTModel(MegatronModule):
         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
-                self.word_embeddings_weight(),
+                self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights else self.word_embeddings_weight(),
                 self.parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
...
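Downstream, post_language_model_processing (unchanged by this commit) projects lm_output onto the vocabulary with whichever weight is passed here. A simplified, single-GPU sketch of that effect; it ignores the vocab-parallel matmul and fp16 cross entropy, and all names and sizes below are made up for illustration:

```python
import torch
import torch.nn.functional as F

def simple_post_processing(lm_output, labels, logit_weight):
    # lm_output: [seq, batch, hidden]; logit_weight: [vocab, hidden]
    logits = F.linear(lm_output, logit_weight)          # [seq, batch, vocab]
    if labels is None:
        return logits
    return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))

seq, batch, hidden, vocab = 4, 2, 8, 16
lm_output = torch.randn(seq, batch, hidden)
labels = torch.randint(0, vocab, (seq, batch))

tied_weight = torch.randn(vocab, hidden)    # stands in for word_embeddings_weight()
untied_weight = torch.randn(vocab, hidden)  # stands in for output_layer.weight
print(simple_post_processing(lm_output, labels, tied_weight))
print(simple_post_processing(lm_output, labels, untied_weight))
```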
@@ -326,8 +326,9 @@ class TransformerLanguageModel(MegatronModule):
                  add_pooler=False,
                  pre_process=True,
                  post_process=True):
-        super(TransformerLanguageModel, self).__init__()
         args = get_args()
+        # TODO: passing share_word_embeddings=False will not work correctly for T5 and embeddings will not be synced. Fix later for T5.
+        super(TransformerLanguageModel, self).__init__(share_word_embeddings=not args.untie_embeddings_and_output_weights)

         self.pre_process = pre_process
         self.post_process = post_process
@@ -340,6 +341,7 @@ class TransformerLanguageModel(MegatronModule):
         self.decoder_attn_mask_type = decoder_attn_mask_type
         self.add_pooler = add_pooler
         self.encoder_hidden_state = None
+        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

         # Embeddings.
         if self.pre_process:
@@ -408,6 +410,14 @@ class TransformerLanguageModel(MegatronModule):
             self.pooler = Pooler(self.hidden_size, self.init_method)
             self._pooler_key = 'pooler'

+        if self.untie_embeddings_and_output_weights:
+            self.output_layer = tensor_parallel.ColumnParallelLinear(
+                args.hidden_size,
+                args.padded_vocab_size,
+                bias=False,  # Setting bias to False always to keep it consistent with embedding tying that also does not have a bias.
+                init_method=output_layer_init_method)
+            self._output_layer_key = 'output_layer'
+
     def set_input_tensor(self, input_tensor):
         """ See megatron.model.transformer.set_input_tensor()"""
@@ -529,6 +539,10 @@ class TransformerLanguageModel(MegatronModule):
             state_dict_[self._pooler_key] \
                 = self.pooler.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
+        if self.untie_embeddings_and_output_weights:
+            state_dict_[self._output_layer_key] \
+                = self.output_layer.state_dict_for_save_checkpoint(prefix=prefix,
+                                                                   keep_vars=keep_vars)
         if self.add_decoder:
             state_dict_[self._decoder_key] \
                 = self.decoder.state_dict_for_save_checkpoint(prefix=prefix,
@@ -584,6 +598,11 @@ class TransformerLanguageModel(MegatronModule):
                 'could not find data for pooler in the checkpoint'
             self.pooler.load_state_dict(state_dict[self._pooler_key],
                                         strict=strict)
+        if self.untie_embeddings_and_output_weights:
+            assert 'output_layer' in state_dict, \
+                'could not find data for output_layer in the checkpoint'
+            self.output_layer.load_state_dict(state_dict[self._output_layer_key],
+                                              strict=strict)
         # Decoder.
         if self.add_decoder:
             assert 'decoder' in state_dict, \
...
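One practical consequence of the save/load changes: checkpoints written before this commit (or without the flag) carry no 'output_layer' entry, so resuming them with --untie-embeddings-and-output-weights will hit the new assertion. A toy sketch of that check, with placeholder dict contents:

```python
def check_loadable(state_dict, untie_embeddings_and_output_weights):
    # Mirrors the assertion added in load_state_dict above.
    if untie_embeddings_and_output_weights:
        assert 'output_layer' in state_dict, \
            'could not find data for output_layer in the checkpoint'
    return True

tied_ckpt = {'embedding': {}, 'encoder': {}}                        # saved with tied weights
untied_ckpt = {'embedding': {}, 'encoder': {}, 'output_layer': {}}  # saved with the flag on

print(check_loadable(untied_ckpt, True))   # True
print(check_loadable(tied_ckpt, False))    # True: the tied load path is unchanged
# check_loadable(tied_ckpt, True)          # would raise AssertionError
```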
@@ -71,7 +71,7 @@ class MegatronModule(torch.nn.Module):
         # the two word_embeddings layers to ensure that every applied weight
         # update is the same on both stages.
         if mpu.is_pipeline_last_stage() and \
-                not self.pre_process:
+                not self.pre_process and not self.untie_embeddings_and_output_weights:
             assert not mpu.is_pipeline_first_stage()
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # set word_embeddings weights to 0 here, then copy first
...
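The extended condition means the last pipeline stage only builds its word_embeddings_for_head copy (and joins the tied-weight all-reduce) when the weights are actually tied; when untied, the output_layer added above is used instead. A truth-table sketch with plain booleans standing in for mpu.is_pipeline_last_stage() and pre_process (the helper name is made up):

```python
def needs_embedding_copy_for_head(is_last_stage, pre_process, untied):
    # Same condition as the guard above, written as a pure function.
    return is_last_stage and not pre_process and not untied

print(needs_embedding_copy_for_head(True, False, False))  # True:  tied, last stage keeps a copy
print(needs_embedding_copy_for_head(True, False, True))   # False: untied, output_layer is used
print(needs_embedding_copy_for_head(False, True, False))  # False: first stage owns the embedding
```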