Commit 63740223 authored by Deepak Narayanan

Improve time logging when num_microbatches_in_minibatch > 1

Make sure all forward and backward operations are accounted for
parent d5b526d5
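For context, the hunks below move the 'forward' and 'backward' timers out of the per-microbatch helpers and up into train_step, so that a single 'forward' interval brackets the warmup passes plus the 1F1B steady state, and a single 'backward' interval brackets the cooldown backward passes through the remaining backward-side work (gradient all-reduces, master-grad copy, clipping). The following is a minimal, self-contained sketch of that coarse timer placement as it appears in the diff; the Timers and _Timer classes, train_step_sketch, and the forward/backward stand-ins are illustrative toys, not Megatron-LM's actual implementation.

    # Illustrative sketch only -- not Megatron-LM's code. A toy named-timer
    # helper and a simplified 1F1B-style schedule showing the coarse timer
    # placement the train_step hunks below appear to adopt.
    import time


    class _Timer:
        def __init__(self):
            self.elapsed = 0.0
            self._start_time = None

        def start(self):
            self._start_time = time.time()

        def stop(self):
            self.elapsed += time.time() - self._start_time
            self._start_time = None


    class Timers:
        """Maps a name to a timer, mimicking the timers('name').start() call style."""

        def __init__(self):
            self._timers = {}

        def __call__(self, name):
            return self._timers.setdefault(name, _Timer())


    def train_step_sketch(num_microbatches_in_minibatch,
                          num_warmup_microbatches, timers):
        # Hypothetical stand-ins for the real per-microbatch forward/backward steps.
        def forward_microbatch():
            time.sleep(0.001)

        def backward_microbatch():
            time.sleep(0.001)

        # One 'forward' interval covers warmup plus the 1F1B steady state, so
        # the backward work interleaved there is not dropped from the accounting.
        timers('forward').start()
        for _ in range(num_warmup_microbatches):
            forward_microbatch()
        for _ in range(num_microbatches_in_minibatch - num_warmup_microbatches):
            forward_microbatch()
            backward_microbatch()
        timers('forward').stop()

        # One 'backward' interval covers the cooldown backward passes; in the
        # real train_step it is stopped only after the later backward-side work.
        timers('backward').start()
        for _ in range(num_warmup_microbatches):
            backward_microbatch()
        timers('backward').stop()


    timers = Timers()
    train_step_sketch(num_microbatches_in_minibatch=8,
                      num_warmup_microbatches=3, timers=timers)
    print('forward: %.4f s | backward: %.4f s' %
          (timers('forward').elapsed, timers('backward').elapsed))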
@@ -304,9 +304,7 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
        input_tensor = None

    # Forward model for one step.
    timers('forward').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
@@ -338,11 +336,8 @@ def backward_step_with_communication(optimizer, model, input_tensors, output_tensors,
            recv_backward=True)

    # Backward pass for one step.
    # TODO: This timer is a bit redundant now with backward-backward.
    timers('backward').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward').stop()

    if not mpu.is_pipeline_first_stage():
        communicate(
@@ -381,22 +376,16 @@ def train_step(forward_step_func, data_iterator,
    losses_reduced = []

    # Run warmup forward passes.
    timers('forward').start()
    for i in range(num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors,
            losses_reduced, timers)

    # Run 1F1B.
    for i in range(num_microbatches_in_minibatch - num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors,
            losses_reduced, timers)
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)
    timers('forward').stop()

    # Run cooldown backward passes.
    timers('backward').start()
    for i in range(num_warmup_microbatches):
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)
@@ -415,6 +404,7 @@ def train_step(forward_step_func, data_iterator,
    timers('backward-master-grad').stop()

    # All-reduce across first and last stages.
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
            args.pipeline_model_parallel_size > 1:
        unwrapped_model = model
@@ -424,6 +414,7 @@ def train_step(forward_step_func, data_iterator,
        word_embeddings_weight = unwrapped_model.word_embeddings_weight()
        torch.distributed.all_reduce(word_embeddings_weight.grad,
                                     group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Clipping gradients helps prevent the exploding gradient.
    timers('backward-clip-grad').start()
@@ -440,6 +431,7 @@ def train_step(forward_step_func, data_iterator,
    else:
        optimizer.clip_master_grads(args.clip_grad)
    timers('backward-clip-grad').stop()
    timers('backward').stop()

    # Update parameters.
    timers('optimizer').start()
@@ -503,6 +495,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
    add_to_logging('backward-backward')
    add_to_logging('backward-allreduce')
    add_to_logging('backward-master-grad')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('backward-clip-grad')
    add_to_logging('optimizer')
    add_to_logging('batch generator')
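For reference, a small hypothetical illustration (not the actual training_log code) of the add_to_logging pattern the last hunk extends with the new 'backward-embedding-all-reduce' timer: a name is recorded only if a corresponding timer exists, and the collected names are then reported together. The elapsed values below are placeholder numbers, not measurements.

    # Hypothetical illustration of the add_to_logging pattern in training_log.
    elapsed_seconds = {
        'forward': 1.23,
        'backward': 2.34,
        'backward-embedding-all-reduce': 0.05,
    }

    timers_to_log = []

    def add_to_logging(name):
        # Only report timers that were actually started on this rank.
        if name in elapsed_seconds:
            timers_to_log.append(name)

    add_to_logging('forward')
    add_to_logging('backward')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('backward-clip-grad')  # not present in this toy dict, so skipped

    print(' | '.join('%s: %.2f s' % (name, elapsed_seconds[name])
                     for name in timers_to_log))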