Commit 7be2648a authored by Jared Casper

Clarify module.initialize_word_embeddings.

parent c4c68dce
@@ -60,8 +60,13 @@ class MegatronModule(torch.nn.Module):
         if not self.share_word_embeddings:
             raise Exception('initialize_word_embeddings() was called but '
                             'share_word_embeddings is false')
+
+        # This function just initializes the word embeddings in the final stage
+        # when we are using pipeline parallelism. If we aren't using pipeline
+        # parallelism there is nothing to do.
         if args.pipeline_model_parallel_size == 1:
             return
+
         # Parameters are shared between the word embeddings layer, and the
         # heads at the end of the model. In a pipelined setup with more than
         # one stage, the initial embedding layer and the head are on different
@@ -75,16 +80,16 @@ class MegatronModule(torch.nn.Module):
         # the two word_embeddings layers to ensure that every applied weight
         # update is the same on both stages.
         if mpu.is_pipeline_last_stage():
-            if not mpu.is_pipeline_first_stage():
-                self._word_embeddings_for_head_key = 'word_embeddings_for_head'
-                # If first and last stages are different, set word_embeddings
-                # weights to 0 here, then copy first stage's weights using
-                # all_reduce below.
-                self.word_embeddings = mpu.VocabParallelEmbedding(
-                    args.padded_vocab_size, args.hidden_size,
-                    init_method=init_method_normal(args.init_method_std))
-                self.word_embeddings.weight.data.fill_(0)
-                self.word_embeddings.weight.shared = True
+            assert not mpu.is_pipeline_first_stage()
+            self._word_embeddings_for_head_key = 'word_embeddings_for_head'
+            # set word_embeddings weights to 0 here, then copy first
+            # stage's weights using all_reduce below.
+            self.word_embeddings = mpu.VocabParallelEmbedding(
+                args.padded_vocab_size, args.hidden_size,
+                init_method=init_method_normal(args.init_method_std))
+            self.word_embeddings.weight.data.fill_(0)
+            self.word_embeddings.weight.shared = True
+
         # Ensure that first and last stages have the same initial parameter
         # values.
         if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
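The comments in the hunk above describe the trick this code relies on: the last pipeline stage creates its own copy of the word embeddings, fills it with zeros, and an all-reduce over the embedding group then leaves both the first and last stages holding the first stage's initialized weights. The following is a minimal standalone sketch of that pattern, not code from the commit; the helper name sync_shared_embedding, the two-process gloo setup, and the toy tensor sizes are illustrative assumptions standing in for Megatron's mpu groups and VocabParallelEmbedding.

# Sketch only (assumed names): zero-init on the "last stage", then all-reduce
# between the first and last stages so both copies start out identical.
import torch
import torch.distributed as dist


def sync_shared_embedding(weight: torch.Tensor, is_first_stage: bool,
                          is_last_stage: bool, group) -> None:
    # The last stage's copy is all zeros, so summing the two copies leaves
    # both stages holding the first stage's initialized values.
    if is_first_stage or is_last_stage:
        dist.all_reduce(weight, group=group)


if __name__ == '__main__':
    # Launch with: torchrun --nproc_per_node=2 this_file.py
    dist.init_process_group('gloo')
    rank = dist.get_rank()
    last_rank = dist.get_world_size() - 1
    # Stand-in for mpu.get_embedding_group(): just the first and last ranks.
    embedding_group = dist.new_group([0, last_rank])

    weight = torch.empty(8, 4)
    if rank == 0:                      # first stage: normal initialization
        torch.nn.init.normal_(weight, std=0.02)
    else:                              # last stage: zeros, overwritten below
        weight.fill_(0.0)

    sync_shared_embedding(weight, rank == 0, rank == last_rank, embedding_group)
    # After the all-reduce, both ranks print the same checksum.
    print(f'rank {rank}: checksum {weight.sum().item():.6f}')

During training the same all-reduce is applied to the gradients of the two copies, so every weight update stays identical on both stages, which is why the initial values must match in the first place.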