Commit 63740223 authored by Deepak Narayanan

Improve time logging when num_microbatches_in_minibatch > 1

Make sure all forward and backward operations are accounted for
parent d5b526d5
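The diff below removes the per-microbatch timers('forward') / timers('backward') start/stop calls from forward_step_with_communication and backward_step_with_communication and instead times the larger phases of train_step, adding a separate backward-embedding-all-reduce region. The sketch that follows is only a rough illustration of that timing structure, inferred from the timer names visible in the hunks; the function signature, the helper stand-ins, and the handling of the 1F1B loop are simplified assumptions, not Megatron-LM's actual code.

# Sketch (not the actual Megatron-LM code): where the phase-level timers sit in
# train_step after this change. Helper calls are hypothetical no-op stand-ins.

def train_step_timing_sketch(timers, forward_microbatch, backward_microbatch,
                             num_warmup_microbatches,
                             num_microbatches_in_minibatch):
    # Warmup forward passes: the timer now wraps the whole loop, so the
    # send/recv communication inside the helper is counted, not just
    # forward_step_func.
    timers('forward').start()
    for _ in range(num_warmup_microbatches):
        forward_microbatch()   # stands in for forward_step_with_communication(...)
    timers('forward').stop()

    # 1F1B steady state. How this loop is timed is not fully visible in the
    # diff, so it is left untimed here.
    for _ in range(num_microbatches_in_minibatch - num_warmup_microbatches):
        forward_microbatch()
        backward_microbatch()  # stands in for backward_step_with_communication(...)

    # Cooldown backward passes plus the remaining backward-side work.
    timers('backward').start()
    for _ in range(num_warmup_microbatches):
        backward_microbatch()

    # New region: all-reduce of the shared word-embedding gradient between the
    # first and last pipeline stages.
    timers('backward-embedding-all-reduce').start()
    # torch.distributed.all_reduce(word_embeddings_weight.grad, ...) in the real code
    timers('backward-embedding-all-reduce').stop()

    # Gradient clipping ('backward-clip-grad') also happens before this point in
    # the real code; 'backward' now stops only after it, so every backward-side
    # operation is accounted for.
    timers('backward').stop()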
@@ -304,9 +304,7 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
        input_tensor = None

    # Forward model for one step.
    timers('forward').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
@@ -338,11 +336,8 @@ def backward_step_with_communication(optimizer, model, input_tensors, output_ten
            recv_backward=True)

    # Backward pass for one step.
    # TODO: This timer is a bit redundant now with backward-backward.
    timers('backward').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward').stop()

    if not mpu.is_pipeline_first_stage():
        communicate(
@@ -381,22 +376,16 @@ def train_step(forward_step_func, data_iterator,
    losses_reduced = []

    # Run warmup forward passes.
    timers('forward').start()
    for i in range(num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors,
            losses_reduced, timers)
    timers('forward').stop()

    # Run 1F1B.
    for i in range(num_microbatches_in_minibatch - num_warmup_microbatches):
        forward_step_with_communication(
            forward_step_func, data_iterator, model,
            input_tensors, output_tensors,
            losses_reduced, timers)
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)

    # Run cooldown backward passes.
    timers('backward').start()
    for i in range(num_warmup_microbatches):
        backward_step_with_communication(
            optimizer, model, input_tensors, output_tensors, timers)
@@ -415,6 +404,7 @@ def train_step(forward_step_func, data_iterator,
        timers('backward-master-grad').stop()

    # All-reduce across first and last stages.
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
            args.pipeline_model_parallel_size > 1:
        unwrapped_model = model
@@ -424,6 +414,7 @@ def train_step(forward_step_func, data_iterator,
        word_embeddings_weight = unwrapped_model.word_embeddings_weight()
        torch.distributed.all_reduce(word_embeddings_weight.grad,
                                     group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Clipping gradients helps prevent the exploding gradient.
    timers('backward-clip-grad').start()
@@ -440,6 +431,7 @@ def train_step(forward_step_func, data_iterator,
        else:
            optimizer.clip_master_grads(args.clip_grad)
    timers('backward-clip-grad').stop()
    timers('backward').stop()

    # Update parameters.
    timers('optimizer').start()
@@ -503,6 +495,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
    add_to_logging('backward-backward')
    add_to_logging('backward-allreduce')
    add_to_logging('backward-master-grad')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('backward-clip-grad')
    add_to_logging('optimizer')
    add_to_logging('batch generator')