Commit eed0062a authored by Deepak Narayanan

Log times for various sub-operations in forward and backward pass in main training loop

parent 2d8de296
@@ -293,27 +293,33 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
                                     input_tensors, output_tensors,
                                     losses_reduced, timers):
     if not mpu.is_pipeline_first_stage():
+        timers('forward-recv').start()
         input_tensor, _ = communicate(
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_forward=True,
             recv_backward=False)
+        timers('forward-recv').stop()
     else:
         input_tensor = None

     # Forward model for one step.
+    timers('forward-compute').start()
     output_tensor = forward_step_func(data_iterator, model, input_tensor)
+    timers('forward-compute').stop()

     if mpu.is_pipeline_last_stage():
         loss, loss_reduced = output_tensor
         output_tensor = loss
         losses_reduced.append(loss_reduced)
     else:
+        timers('forward-send').start()
         communicate(
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_forward=False,
             recv_backward=False)
+        timers('forward-send').stop()

     input_tensors.append(input_tensor)
     output_tensors.append(output_tensor)
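Each timed send/recv above is a blocking point-to-point exchange with an adjacent pipeline stage, so 'forward-recv' (and 'backward-recv' below) captures pipeline wait time as well as transfer time. A minimal sketch of the contract communicate(...) appears to satisfy, written against torch.distributed point-to-point ops; the tensor shape, dtype, and raw-rank peer arithmetic are placeholder assumptions, since the real helper derives them from mpu's pipeline-parallel state:

import torch
import torch.distributed as dist

def communicate(tensor_send_next, tensor_send_prev,
                recv_forward, recv_backward,
                tensor_shape=(8, 1024), dtype=torch.float32):
    # Sketch: exchange activations/gradients with the previous and next
    # pipeline stages, blocking until all transfers complete. Returns
    # (tensor_recv_prev, tensor_recv_next), matching the
    # `input_tensor, _ = communicate(...)` unpacking above.
    rank = dist.get_rank()  # assumption: rank order == pipeline-stage order
    tensor_recv_prev = torch.empty(tensor_shape, dtype=dtype) if recv_forward else None
    tensor_recv_next = torch.empty(tensor_shape, dtype=dtype) if recv_backward else None
    ops = []
    if tensor_send_prev is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_prev, rank - 1))
    if tensor_recv_prev is not None:
        ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, rank - 1))
    if tensor_send_next is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_next, rank + 1))
    if tensor_recv_next is not None:
        ops.append(dist.P2POp(dist.irecv, tensor_recv_next, rank + 1))
    if ops:
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    return tensor_recv_prev, tensor_recv_next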
@@ -327,22 +333,28 @@ def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
     else:
+        timers('backward-recv').start()
         _, output_tensor_grad = communicate(
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_forward=False,
             recv_backward=True)
+        timers('backward-recv').stop()

     # Backward pass for one step.
+    timers('backward-compute').start()
     input_grad_tensor = \
         backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
+    timers('backward-compute').stop()

     if not mpu.is_pipeline_first_stage():
+        timers('backward-send').start()
         communicate(
             tensor_send_next=None,
             tensor_send_prev=input_grad_tensor,
             recv_forward=False,
             recv_backward=False)
+        timers('backward-send').stop()


 def train_step(forward_step_func, data_iterator,
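backward_step(...) is the compute-only piece between the timed recv and send: it runs autograd for one microbatch and hands back the gradient with respect to this stage's input so the caller can send it upstream. A plausible minimal sketch, assuming the received input_tensor had requires_grad set before the forward pass (the real version also routes fp16 loss scaling through the optimizer):

import torch

def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
    # Sketch: backward pass for one microbatch on this pipeline stage.
    if input_tensor is not None:
        input_tensor.retain_grad()
    # On the last stage output_tensor is the scalar loss and
    # output_tensor_grad is None; on earlier stages it is the gradient
    # just received from the next stage ('backward-recv').
    torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)
    # None on the first stage, which has nothing to send upstream.
    return None if input_tensor is None else input_tensor.grad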
@@ -385,12 +397,14 @@ def train_step(forward_step_func, data_iterator,
                 input_tensors, output_tensors,
                 losses_reduced, timers)
         else:
+            timers('forward-compute').start()
             input_tensor = None
             loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor)
             output_tensor = loss
             losses_reduced.append(loss_reduced)
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
+            timers('forward-compute').stop()
     timers('forward').stop()

     # Run cooldown backward passes.
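Every timers('name') call in these hunks goes through Megatron's timer registry, and creation-on-first-use is what lets training_log below test membership in timers.timers. A minimal sketch of the interface the diff relies on; the real class also synchronizes CUDA around start/stop so queued GPU work is actually counted:

import time

class _Timer:
    # One named timer; accumulates wall-clock time across start/stop pairs.
    def __init__(self, name):
        self.name = name
        self.elapsed_ = 0.0
        self.started = False
        self.start_time = 0.0

    def start(self):
        assert not self.started, 'timer has already been started'
        self.start_time = time.time()
        self.started = True

    def stop(self):
        assert self.started, 'timer is not started'
        self.elapsed_ += time.time() - self.start_time
        self.started = False

    def elapsed(self, reset=True):
        total = self.elapsed_
        if reset:
            self.elapsed_ = 0.0
        return total

class Timers:
    # Callable registry: timers('forward-recv') creates the timer on
    # first use and returns it thereafter.
    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def log(self, names, normalizer=1.0):
        # Illustrative formatting: per-name time in ms, averaged over
        # `normalizer` iterations, printed on a single line.
        fields = ['{}: {:.2f}'.format(
                      name, self.timers[name].elapsed() * 1000.0 / normalizer)
                  for name in names]
        print('time (ms) | ' + ' | '.join(fields))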
@@ -400,10 +414,12 @@ def train_step(forward_step_func, data_iterator,
             backward_step_with_communication(
                 optimizer, model, input_tensors, output_tensors, timers)
         else:
+            timers('backward-compute').start()
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)
             output_tensor_grad = None
             backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
+            timers('backward-compute').stop()

     # All-reduce if needed.
     if args.DDP_impl == 'local':
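The trailing context is the data-parallel gradient all-reduce that runs after all backward passes; it is timed separately under 'backward-params-all-reduce' in the logging list below. With args.DDP_impl == 'local', Megatron's own DDP wrapper reduces gradients explicitly rather than overlapping them inside torch's DistributedDataParallel. A hypothetical sketch of that step (the real code coalesces gradients into contiguous buffers first; the function name and signature here are illustrative):

import torch.distributed as dist

def allreduce_gradients(model, data_parallel_group=None):
    # Sketch: average parameter gradients across data-parallel replicas
    # once all microbatch backward passes have finished.
    world_size = dist.get_world_size(group=data_parallel_group)
    for param in model.parameters():
        if param.grad is not None:
            param.grad.data.div_(world_size)
            dist.all_reduce(param.grad.data, group=data_parallel_group)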
@@ -509,7 +525,13 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         if name in timers.timers:
             timers_to_log.append(name)
     add_to_logging('forward')
+    add_to_logging('forward-compute')
+    add_to_logging('forward-recv')
+    add_to_logging('forward-send')
     add_to_logging('backward')
+    add_to_logging('backward-compute')
+    add_to_logging('backward-recv')
+    add_to_logging('backward-send')
     add_to_logging('backward-master-grad')
     add_to_logging('backward-params-all-reduce')
     add_to_logging('backward-embedding-all-reduce')
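Because each timer is created only when it first starts, the membership test in add_to_logging makes every rank log exactly the sub-operations it ran; for example, 'forward-recv' never exists on the first pipeline stage. A toy run against the Timers sketch above (the printed format is illustrative):

timers = Timers()
timers('forward-compute').start()
sum(range(10 ** 6))  # stand-in for the forward pass
timers('forward-compute').stop()

timers_to_log = []
def add_to_logging(name):
    if name in timers.timers:
        timers_to_log.append(name)

add_to_logging('forward-compute')
add_to_logging('forward-recv')  # skipped: never started on this rank
timers.log(timers_to_log)       # -> time (ms) | forward-compute: 12.34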