Commit 8fb2bc8c authored by Deepak Narayanan's avatar Deepak Narayanan

Clarifications in comments and minor refactoring to make main training loop more readable

parent 1271fd73
......@@ -577,7 +577,6 @@ class FP16_Optimizer(object):
# a loss scale that works. After you find a loss scale that works, do a final dummy
# backward pass with retain_graph=False to tear down the graph. Doing this would avoid
# discarding the iteration, but probably wouldn't improve overall efficiency.
# Convert output_tensor to float if it is the loss; otherwise, keep it in half precision.
self.loss_scaler.backward(output_tensor, retain_graph=retain_graph,
output_tensor_grad=output_tensor_grad)
if update_master_grads:
......
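For context, when output_tensor is the loss, a loss scaler's backward typically multiplies it by the current loss scale before backpropagating so that small fp16 gradients do not underflow. A minimal sketch of that idea, assuming a simple static scale rather than Megatron's actual loss_scaler object:

import torch

def scaled_backward(loss, loss_scale=1024.0, retain_graph=False):
    # Scale the loss (in fp32) so that small fp16 gradients do not
    # underflow during backpropagation. The resulting gradients must be
    # divided by loss_scale again before the optimizer step.
    (loss.float() * loss_scale).backward(retain_graph=retain_graph)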
......@@ -149,6 +149,17 @@ class BertModelBase(MegatronModule):
init_method=init_method,
scaled_init_method=scaled_init_method)
# Parameters are shared between the word embeddings layer and the heads at
# the end of the model. In a pipelined setup with more than one stage, the
# initial embedding layer and the head are on different workers, so we do
# the following:
# 1. Create a second copy of word_embeddings on the last stage, with initial
# parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that the
# two copies of word_embeddings start off with the same parameter values.
# 3. In the training loop, do an all-reduce between the grads of the two
# word_embeddings layers to ensure that every applied weight update is the
# same on both stages.
if mpu.is_pipeline_last_stage():
if not mpu.is_pipeline_first_stage():
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
......@@ -169,8 +180,7 @@ class BertModelBase(MegatronModule):
self.binary_head = get_linear_layer(args.hidden_size, 2,
init_method)
self._binary_head_key = 'binary_head'
# Ensure that first and last stages have the same initial embedding weights.
# Ensure that first and last stages have the same initial parameter values.
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
......
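A condensed sketch of steps 1 and 2 from the comment above, assuming torch.distributed is already initialized and that embedding_group contains exactly the first- and last-stage ranks; the function name and arguments are illustrative, not Megatron's API:

import torch.distributed as dist

def sync_initial_word_embeddings(word_embeddings_weight, embedding_group,
                                 is_first_stage, is_last_stage):
    # Step 1 (done at construction time in the diff above): the last stage
    # builds its own copy of the word embeddings and zero-fills it.
    if is_last_stage and not is_first_stage:
        word_embeddings_weight.data.fill_(0)
    # Step 2: summing the first stage's weights with the last stage's zeros
    # leaves both ranks holding identical initial parameter values.
    if is_first_stage or is_last_stage:
        dist.all_reduce(word_embeddings_weight.data, group=embedding_group)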
......@@ -79,6 +79,17 @@ class GPT2ModelBase(MegatronModule):
scaled_init_method=scaled_init_method_normal(args.init_method_std,
args.num_layers))
# Parameters are shared between the word embeddings layer and the heads at
# the end of the model. In a pipelined setup with more than one stage, the
# initial embedding layer and the head are on different workers, so we do
# the following:
# 1. Create a second copy of word_embeddings on the last stage, with initial
# parameters of 0.0.
# 2. Do an all-reduce between the first and last stage to ensure that the
# two copies of word_embeddings start off with the same parameter values.
# 3. In the training loop, do an all-reduce between the grads of the two
# word_embeddings layers to ensure that every applied weight update is the
# same on both stages.
if mpu.is_pipeline_last_stage():
if not mpu.is_pipeline_first_stage():
self._word_embeddings_for_head_key = 'word_embeddings_for_head'
......@@ -89,8 +100,7 @@ class GPT2ModelBase(MegatronModule):
args.padded_vocab_size, args.hidden_size,
init_method=init_method_normal(args.init_method_std))
self.word_embeddings.weight.data.fill_(0)
# Ensure that first and last stages have the same initial embedding weights.
# Ensure that first and last stages have the same initial parameter values.
if mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage():
torch.distributed.all_reduce(self.word_embeddings_weight().data,
group=mpu.get_embedding_group())
......
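Step 3 from the comment above runs once per training step, after the backward passes (the matching timer appears in the train_step diff further down). A minimal sketch, again with an illustrative helper name and the assumption that the gradient is held in the weight's .grad attribute:

import torch.distributed as dist

def allreduce_word_embedding_grads(word_embeddings_weight, embedding_group,
                                   is_first_stage, is_last_stage):
    # Summing the first-stage and last-stage gradients gives both ranks the
    # same gradient, so the optimizer applies an identical update to both
    # copies of the tied word embeddings.
    if is_first_stage or is_last_stage:
        dist.all_reduce(word_embeddings_weight.grad, group=embedding_group)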
......@@ -18,8 +18,7 @@ def general_ict_model_provider(only_query_model=False, only_block_model=False):
args = get_args()
assert args.ict_head_size is not None, \
"Need to specify --ict-head-size to provide an ICTBertModel"
assert args.tensor_model_parallel_size == 1, \
assert args.tensor_model_parallel_size == 1 and args.pipeline_model_parallel_size == 1, \
"Model parallel size > 1 not supported for ICT"
print_rank_0('building ICTBertModel...')
......
......@@ -361,6 +361,9 @@ def train_step(forward_step_func, data_iterator,
# Compute number of microbatches in a minibatch.
num_microbatches_in_minibatch = args.num_microbatches_in_minibatch
# For now, perform training without warmup. Perform forward
# passes for all microbatches, then backward passes for all
# microbatches.
# TODO: Switch to the following schedule to facilitate more
# memory-efficient training.
# num_warmup_microbatches = \
......@@ -369,9 +372,6 @@ def train_step(forward_step_func, data_iterator,
# num_warmup_microbatches = min(
# num_warmup_microbatches,
# num_microbatches_in_minibatch)
# For now, perform training without warmup. Perform forward
# passes for all microbatches, then backward passes for all
# microbatches.
num_warmup_microbatches = num_microbatches_in_minibatch
input_tensors = []
......@@ -381,17 +381,31 @@ def train_step(forward_step_func, data_iterator,
# Run warmup forward passes.
timers('forward').start()
for i in range(num_warmup_microbatches):
forward_step_with_communication(
forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers)
if args.pipeline_model_parallel_size > 1:
forward_step_with_communication(
forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers)
else:
input_tensor = None
loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor)
output_tensor = loss
losses_reduced.append(loss_reduced)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
timers('forward').stop()
# Run cooldown backward passes.
timers('backward').start()
for i in range(num_warmup_microbatches):
backward_step_with_communication(
optimizer, model, input_tensors, output_tensors, timers)
if args.pipeline_model_parallel_size > 1:
backward_step_with_communication(
optimizer, model, input_tensors, output_tensors, timers)
else:
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = None
backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
# All-reduce if needed.
if args.DDP_impl == 'local':
......@@ -400,7 +414,10 @@ def train_step(forward_step_func, data_iterator,
fp32_allreduce=args.fp32_allreduce)
timers('allreduce').stop()
# All-reduce across first and last stages.
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
args.pipeline_model_parallel_size > 1:
......
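The branches added above distinguish the pipelined path (forward and backward steps that also communicate activations between stages) from the single-stage path. A simplified sketch of the single-stage, no-warmup schedule they implement — every microbatch's forward pass first, then every backward pass — reusing the forward_step_func and backward_step callables from the diff (not defined here):

def train_step_no_warmup(forward_step_func, data_iterator, model, optimizer,
                         num_microbatches):
    # Single-stage variant: no pipeline communication is needed.
    input_tensors, output_tensors, losses_reduced = [], [], []

    # Forward passes for all microbatches. There is no preceding stage, so
    # the input activation is None and the output tensor is the loss.
    for _ in range(num_microbatches):
        input_tensor = None
        loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor)
        losses_reduced.append(loss_reduced)
        input_tensors.append(input_tensor)
        output_tensors.append(loss)

    # Backward passes for all microbatches, in the same order. There is no
    # following stage, so the incoming output gradient is None.
    for _ in range(num_microbatches):
        input_tensor = input_tensors.pop(0)
        output_tensor = output_tensors.pop(0)
        backward_step(optimizer, model, input_tensor, output_tensor, None)

    return losses_reduced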
......@@ -32,7 +32,6 @@ from megatron.data.realm_dataset_utils import get_ict_batch
def pretrain_ict_model_provider():
args = get_args()
assert args.pipeline_model_parallel_size == 1, 'pipeline_model_parallel_size must be 1!'
return general_ict_model_provider(False, False)
......