Commit eed0062a authored by Deepak Narayanan

Log times for various sub-operations in forward and backward pass in main training loop

parent 2d8de296
@@ -293,27 +293,33 @@ def forward_step_with_communication(forward_step_func, data_iterator, model,
                                     input_tensors, output_tensors,
                                     losses_reduced, timers):
     if not mpu.is_pipeline_first_stage():
+        timers('forward-recv').start()
         input_tensor, _ = communicate(
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_forward=True,
             recv_backward=False)
+        timers('forward-recv').stop()
     else:
         input_tensor = None

     # Forward model for one step.
+    timers('forward-compute').start()
     output_tensor = forward_step_func(data_iterator, model, input_tensor)
+    timers('forward-compute').stop()

     if mpu.is_pipeline_last_stage():
         loss, loss_reduced = output_tensor
         output_tensor = loss
         losses_reduced.append(loss_reduced)
     else:
+        timers('forward-send').start()
         communicate(
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_forward=False,
             recv_backward=False)
+        timers('forward-send').stop()

     input_tensors.append(input_tensor)
     output_tensors.append(output_tensor)
@@ -327,22 +333,28 @@ def backward_step_with_communication(optimizer, model, input_tensors, output_ten
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
     else:
+        timers('backward-recv').start()
         _, output_tensor_grad = communicate(
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_forward=False,
             recv_backward=True)
+        timers('backward-recv').stop()

     # Backward pass for one step.
+    timers('backward-compute').start()
     input_grad_tensor = \
         backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
+    timers('backward-compute').stop()

     if not mpu.is_pipeline_first_stage():
+        timers('backward-send').start()
         communicate(
             tensor_send_next=None,
             tensor_send_prev=input_grad_tensor,
             recv_forward=False,
             recv_backward=False)
+        timers('backward-send').stop()


 def train_step(forward_step_func, data_iterator,
@@ -385,12 +397,14 @@ def train_step(forward_step_func, data_iterator,
                 input_tensors, output_tensors,
                 losses_reduced, timers)
         else:
+            timers('forward-compute').start()
             input_tensor = None
             loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor)
             output_tensor = loss
             losses_reduced.append(loss_reduced)
             input_tensors.append(input_tensor)
             output_tensors.append(output_tensor)
+            timers('forward-compute').stop()
     timers('forward').stop()

     # Run cooldown backward passes.
@@ -400,10 +414,12 @@ def train_step(forward_step_func, data_iterator,
             backward_step_with_communication(
                 optimizer, model, input_tensors, output_tensors, timers)
         else:
+            timers('backward-compute').start()
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)
             output_tensor_grad = None
             backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
+            timers('backward-compute').stop()

     # All-reduce if needed.
     if args.DDP_impl == 'local':
@@ -509,7 +525,13 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         if name in timers.timers:
             timers_to_log.append(name)
     add_to_logging('forward')
+    add_to_logging('forward-compute')
+    add_to_logging('forward-recv')
+    add_to_logging('forward-send')
     add_to_logging('backward')
+    add_to_logging('backward-compute')
+    add_to_logging('backward-recv')
+    add_to_logging('backward-send')
     add_to_logging('backward-master-grad')
     add_to_logging('backward-params-all-reduce')
     add_to_logging('backward-embedding-all-reduce')
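Note on the timer interface assumed above: the added calls rely on a registry where timers('name') returns a named timer with start() and stop(), and training_log() only reports names that appear in timers.timers. The following is a minimal sketch of such a helper for illustration only; it is an assumption, not the actual Megatron-LM Timers implementation (which, among other things, synchronizes CUDA before reading the clock).

import time


class _Timer:
    """Accumulates wall-clock time across repeated start()/stop() calls."""

    def __init__(self, name):
        self.name = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = None

    def start(self):
        assert not self.started_, f'timer {self.name} already started'
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        assert self.started_, f'timer {self.name} is not running'
        self.elapsed_ += time.time() - self.start_time
        self.started_ = False

    def elapsed(self, reset=True):
        """Return accumulated seconds, optionally resetting the counter."""
        value = self.elapsed_
        if reset:
            self.elapsed_ = 0.0
        return value


class Timers:
    """Registry of named timers; timers('x') creates or returns timer 'x'."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def log(self, names, normalizer=1.0):
        """Print accumulated times in ms for the requested names, skipping
        timers that were never started (mirroring the timers.timers check
        in training_log())."""
        fields = []
        for name in names:
            if name in self.timers:
                elapsed_ms = self.timers[name].elapsed() * 1000.0 / normalizer
                fields.append(f'{name}: {elapsed_ms:.2f} ms')
        print(' | '.join(fields))


if __name__ == '__main__':
    timers = Timers()
    timers('forward-compute').start()
    time.sleep(0.01)  # stand-in for forward_step_func(...)
    timers('forward-compute').stop()
    # 'forward-recv' and 'forward-send' were never started, so only the
    # compute timer is printed.
    timers.log(['forward-compute', 'forward-recv', 'forward-send'])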