Commit 57c3b364 authored by Deepak Narayanan

Refactor word_embeddings_weight() logic into a separate method, and address other review comments from Mohammad

parent eed0062a
@@ -144,7 +144,7 @@ def _set_random_seed(seed_):
     """Set random seed for reproducibility."""
     if seed_ is not None and seed_ > 0:
         # Ensure that different pipeline MP stages get different seeds.
-        seed = seed_ + mpu.get_pipeline_model_parallel_rank()
+        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
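The new offset spaces pipeline stages 100 apart in seed space rather than 1 apart. A quick illustration of the per-stage seeds before and after this change (plain Python, assuming `get_pipeline_model_parallel_rank()` returns 0..P-1):

```python
# Illustration only: per-stage seeds for a base seed of 1234 and
# four pipeline stages.
base_seed = 1234
for rank in range(4):
    old_seed = base_seed + rank          # 1234, 1235, 1236, 1237
    new_seed = base_seed + (100 * rank)  # 1234, 1334, 1434, 1534
    print(rank, old_seed, new_seed)
```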
@@ -26,7 +26,7 @@ from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
 from megatron.model.utils import scaled_init_method_normal
-from megatron.module import MegatronModule
+from megatron.module import MegatronModule, PipelinedMegatronModule

 def bert_attention_mask_func(attention_scores, attention_mask):
     attention_scores.masked_fill_(attention_mask, -10000.0)
@@ -126,7 +126,7 @@ def post_language_model_processing(lm_output, pooled_output,
     return lm_loss, binary_logits


-class BertModelBase(MegatronModule):
+class BertModelBase(PipelinedMegatronModule):
     """Bert Language model."""

     def __init__(self, num_tokentypes=2, add_binary_head=True,
@@ -149,28 +149,7 @@ class BertModelBase(MegatronModule):
             init_method=init_method,
             scaled_init_method=scaled_init_method)

-        # Parameters are shared between the word embeddings layer and the heads
-        # at the end of the model. In a pipelined setup with more than one stage,
-        # the initial embedding layer and the head are on different workers, so
-        # we do the following:
-        # 1. Create a second copy of word_embeddings on the last stage, with
-        #    initial parameters of 0.0.
-        # 2. Do an all-reduce between the first and last stage to ensure that the
-        #    two copies of word_embeddings start off with the same parameter values.
-        # 3. In the training loop, do an all-reduce between the grads of the two
-        #    word_embeddings layers to ensure that every applied weight update is
-        #    the same on both stages.
-        if mpu.is_pipeline_last_stage():
-            if not mpu.is_pipeline_first_stage():
-                self._word_embeddings_for_head_key = 'word_embeddings_for_head'
-                # If first and last stages are different, set word_embeddings
-                # weights to 0 here, then copy first stage's weights using
-                # all_reduce below.
-                self.word_embeddings = mpu.VocabParallelEmbedding(
-                    args.padded_vocab_size, args.hidden_size,
-                    init_method=init_method_normal(args.init_method_std))
-                self.word_embeddings.weight.data.fill_(0)

         self.lm_head = BertLMHead(
             self.word_embeddings_weight().size(0),
             args.hidden_size, init_method, args.layernorm_epsilon, parallel_output)
@@ -180,18 +159,8 @@ class BertModelBase(MegatronModule):
             self.binary_head = get_linear_layer(args.hidden_size, 2,
                                                 init_method)
             self._binary_head_key = 'binary_head'

-        # Ensure that first and last stages have the same initial parameter values.
-        if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
-            torch.distributed.all_reduce(self.word_embeddings_weight().data,
-                                         group=mpu.get_embedding_group())
-
-    def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage():
-            return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage():
-            return self.word_embeddings.weight
-        raise Exception('word_embeddings_weight() should be '
-                        'called for first and last stage only')
+        self.initialize_word_embeddings(init_method_normal)

     def forward(self, bert_model_input, attention_mask,
                 tokentype_ids=None, lm_labels=None):
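With the shared logic moved into the base class, a model only needs to subclass PipelinedMegatronModule and invoke the helper once construction is done. A minimal sketch of the pattern (toy class and builder names, not from this commit):

```python
# Toy sketch of the pattern this commit introduces; ToyPipelinedModel
# and build_rest_of_model are hypothetical names.
class ToyPipelinedModel(PipelinedMegatronModule):

    def __init__(self):
        super(ToyPipelinedModel, self).__init__()
        self.language_model = build_rest_of_model()
        # Shared helper from PipelinedMegatronModule: creates the extra
        # embedding copy on the last stage and all-reduces initial values.
        self.initialize_word_embeddings(init_method_normal)
```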
@@ -19,7 +19,7 @@ import torch

 from megatron import get_args
 from megatron import mpu
-from megatron.module import MegatronModule
+from megatron.module import PipelinedMegatronModule
 from .language_model import parallel_lm_logits
 from .language_model import get_language_model
@@ -61,7 +61,7 @@ def post_language_model_processing(lm_output, labels, logit_weights,
     return loss


-class GPT2ModelBase(MegatronModule):
+class GPT2ModelBase(PipelinedMegatronModule):
     """GPT-2 Language model."""

     def __init__(self, num_tokentypes=0, parallel_output=True):
@@ -79,39 +79,7 @@ class GPT2ModelBase(MegatronModule):
             scaled_init_method=scaled_init_method_normal(args.init_method_std,
                                                          args.num_layers))

-        # Parameters are shared between the word embeddings layer and the heads
-        # at the end of the model. In a pipelined setup with more than one stage,
-        # the initial embedding layer and the head are on different workers, so
-        # we do the following:
-        # 1. Create a second copy of word_embeddings on the last stage, with
-        #    initial parameters of 0.0.
-        # 2. Do an all-reduce between the first and last stage to ensure that the
-        #    two copies of word_embeddings start off with the same parameter values.
-        # 3. In the training loop, do an all-reduce between the grads of the two
-        #    word_embeddings layers to ensure that every applied weight update is
-        #    the same on both stages.
-        if mpu.is_pipeline_last_stage():
-            if not mpu.is_pipeline_first_stage():
-                self._word_embeddings_for_head_key = 'word_embeddings_for_head'
-                # If first and last stages are different, set word_embeddings
-                # weights to 0 here, then copy first stage's weights using
-                # all_reduce below.
-                self.word_embeddings = mpu.VocabParallelEmbedding(
-                    args.padded_vocab_size, args.hidden_size,
-                    init_method=init_method_normal(args.init_method_std))
-                self.word_embeddings.weight.data.fill_(0)
-
-        # Ensure that first and last stages have the same initial parameter values.
-        if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
-            torch.distributed.all_reduce(self.word_embeddings_weight().data,
-                                         group=mpu.get_embedding_group())
-
-    def word_embeddings_weight(self):
-        if mpu.is_pipeline_first_stage():
-            return self.language_model.embedding.word_embeddings.weight
-        if mpu.is_pipeline_last_stage():
-            return self.word_embeddings.weight
-        raise Exception('word_embeddings_weight() should be '
-                        'called for first and last stage only')
+        self.initialize_word_embeddings(init_method_normal)

     def forward(self, gpt2_model_input, attention_mask, labels=None,
                 tokentype_ids=None, layer_past=None, get_key_value=False,
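Both model bases tie the output projection to the input embedding via word_embeddings_weight(). A standalone illustration of that weight tying in plain PyTorch (not Megatron code; shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

# Plain-PyTorch illustration of embedding/output weight tying. Megatron
# does the equivalent with the tensor returned by word_embeddings_weight().
vocab_size, hidden_size = 50304, 1024
embedding = torch.nn.Embedding(vocab_size, hidden_size)

tokens = torch.randint(vocab_size, (2, 8))           # [batch, seq]
hidden_states = embedding(tokens)                    # stand-in for lm_output
logits = F.linear(hidden_states, embedding.weight)   # tied output head
print(logits.shape)                                  # torch.Size([2, 8, 50304])
```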
@@ -505,6 +505,8 @@ class ParallelTransformer(MegatronModule):
         self.checkpoint_num_layers = args.checkpoint_num_layers

         # Number of layers.
+        assert args.num_layers % args.pipeline_model_parallel_size == 0, \
+            'num_layers must be divisible by pipeline_model_parallel_size'
         self.num_layers = args.num_layers // args.pipeline_model_parallel_size

         # Transformer layers.
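The new assertion makes the even layer split across pipeline stages explicit; for example:

```python
# 24 layers over 4 pipeline stages -> 6 layers per stage; with 5 stages
# the assertion above would fire, since 24 % 5 != 0.
num_layers, pipeline_model_parallel_size = 24, 4
assert num_layers % pipeline_model_parallel_size == 0
print(num_layers // pipeline_model_parallel_size)  # 6
```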
@@ -17,9 +17,12 @@
 import torch

+from megatron import get_args
+from megatron import mpu
+

 class MegatronModule(torch.nn.Module):
-    """Megatron specific extentions of torch Module."""
+    """Megatron specific extensions of torch Module."""

     def __init__(self):
         super(MegatronModule, self).__init__()
@@ -29,3 +32,46 @@ class MegatronModule(torch.nn.Module):
         """Use this function to override the state dict for
         saving checkpoints."""
         return self.state_dict(destination, prefix, keep_vars)
+
+
+class PipelinedMegatronModule(MegatronModule):
+    """Pipelining specific extensions of MegatronModule."""
+
+    def __init__(self):
+        super(PipelinedMegatronModule, self).__init__()
+
+    def word_embeddings_weight(self):
+        if mpu.is_pipeline_first_stage():
+            return self.language_model.embedding.word_embeddings.weight
+        if mpu.is_pipeline_last_stage():
+            return self.word_embeddings.weight
+        raise Exception('word_embeddings_weight() should be '
+                        'called for first and last stage only')
+
+    def initialize_word_embeddings(self, init_method_normal):
+        args = get_args()
+        # Parameters are shared between the word embeddings layer and the heads
+        # at the end of the model. In a pipelined setup with more than one stage,
+        # the initial embedding layer and the head are on different workers, so
+        # we do the following:
+        # 1. Create a second copy of word_embeddings on the last stage, with
+        #    initial parameters of 0.0.
+        # 2. Do an all-reduce between the first and last stage to ensure that the
+        #    two copies of word_embeddings start off with the same parameter values.
+        # 3. In the training loop, do an all-reduce between the grads of the two
+        #    word_embeddings layers to ensure that every applied weight update is
+        #    the same on both stages.
+        if mpu.is_pipeline_last_stage():
+            if not mpu.is_pipeline_first_stage():
+                self._word_embeddings_for_head_key = 'word_embeddings_for_head'
+                # If first and last stages are different, set word_embeddings
+                # weights to 0 here, then copy first stage's weights using
+                # all_reduce below.
+                self.word_embeddings = mpu.VocabParallelEmbedding(
+                    args.padded_vocab_size, args.hidden_size,
+                    init_method=init_method_normal(args.init_method_std))
+                self.word_embeddings.weight.data.fill_(0)
+
+        # Ensure that first and last stages have the same initial parameter values.
+        if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
+            torch.distributed.all_reduce(self.word_embeddings_weight().data,
+                                         group=mpu.get_embedding_group())
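Step 3 of the comment lives in the training loop, which this diff does not touch; a hedged sketch of what that gradient all-reduce could look like (hypothetical helper name, assuming gradients are accumulated in .grad):

```python
# Hypothetical helper, not part of this commit: step 3 from the comment
# above, called after backward() and before the optimizer step.
def allreduce_word_embedding_grads(model):
    # Only the first and last pipeline stages hold a copy of the tied
    # word embeddings, so only they join this all-reduce.
    if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
        grad = model.word_embeddings_weight().grad
        if grad is not None:
            torch.distributed.all_reduce(grad,
                                         group=mpu.get_embedding_group())
```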