Commit a6756bf8 authored by Deepak Narayanan
Browse files

Better 'forward' and 'backward' timing in megatron/training.py

parent 3e6898e6
...@@ -363,6 +363,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat ...@@ -363,6 +363,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
input_tensors, output_tensors, input_tensors, output_tensors,
losses_reduced, timers): losses_reduced, timers):
# Forward model for one step. # Forward model for one step.
timers('forward').start()
timers('forward-compute').start() timers('forward-compute').start()
output_tensor = forward_step_func(data_iterator, model, input_tensor) output_tensor = forward_step_func(data_iterator, model, input_tensor)
timers('forward-compute').stop() timers('forward-compute').stop()
...@@ -374,14 +375,13 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat ...@@ -374,14 +375,13 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
losses_reduced.append(loss_reduced) losses_reduced.append(loss_reduced)
else: else:
timers('forward-send').start() timers('forward-send').start()
timers('backward-recv').start()
_, output_tensor_grad = communicate( _, output_tensor_grad = communicate(
tensor_send_next=output_tensor, tensor_send_next=output_tensor,
tensor_send_prev=None, tensor_send_prev=None,
recv_forward=False, recv_forward=False,
recv_backward=True) recv_backward=True)
timers('forward-send').stop() timers('forward-send').stop()
timers('backward-recv').stop() timers('forward').stop()
input_tensors.append(input_tensor) input_tensors.append(input_tensor)
output_tensors.append(output_tensor) output_tensors.append(output_tensor)
...@@ -390,6 +390,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat ...@@ -390,6 +390,7 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
output_tensor = output_tensors.pop(0) output_tensor = output_tensors.pop(0)
# Backward pass for one step. # Backward pass for one step.
timers('backward').start()
timers('backward-compute').start() timers('backward-compute').start()
input_grad_tensor = \ input_grad_tensor = \
backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad) backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
...@@ -397,16 +398,15 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat ...@@ -397,16 +398,15 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
if not mpu.is_pipeline_first_stage(): if not mpu.is_pipeline_first_stage():
timers('backward-send').start() timers('backward-send').start()
timers('forward-recv').start()
input_tensor, _ = communicate( input_tensor, _ = communicate(
tensor_send_next=None, tensor_send_next=None,
tensor_send_prev=input_grad_tensor, tensor_send_prev=input_grad_tensor,
recv_forward=(not last_microbatch), recv_forward=(not last_microbatch),
recv_backward=False) recv_backward=False)
timers('backward-send').stop() timers('backward-send').stop()
timers('forward-recv').stop()
else: else:
input_tensor = None input_tensor = None
timers('backward').stop()
return input_tensor return input_tensor
...@@ -460,10 +460,12 @@ def train_step(forward_step_func, data_iterator, ...@@ -460,10 +460,12 @@ def train_step(forward_step_func, data_iterator,
if mpu.is_pipeline_first_stage(): if mpu.is_pipeline_first_stage():
input_tensor = None input_tensor = None
else: else:
timers('forward-recv').start()
input_tensor, _ = communicate(tensor_send_next=None, input_tensor, _ = communicate(tensor_send_next=None,
tensor_send_prev=None, tensor_send_prev=None,
recv_forward=True, recv_forward=True,
recv_backward=False) recv_backward=False)
timers('forward-recv').stop()
# Run 1F1B. # Run 1F1B.
for i in range(num_microbatches_in_minibatch - num_warmup_microbatches): for i in range(num_microbatches_in_minibatch - num_warmup_microbatches):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment