Commit 6a680986 authored by Jared Casper

Merge branch 'bugfix' into 'main'

Use timers kwargs correctly to prevent a bug with the new p2p_communication API

See merge request ADLR/megatron-lm!295
parents a676bc2d e727de99
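
The one-line summary is worth unpacking: the old call sites passed timers positionally, and the updated p2p_communication helpers evidently no longer take it in that slot, so the timers object would bind silently to whatever parameter now occupies that position. A minimal sketch of the failure mode, using hypothetical stand-in signatures rather than the real p2p_communication API:

# Hypothetical stand-ins for the old and new helper signatures; the
# actual p2p_communication API in Megatron-LM may differ.

def recv_forward_old(timers=None):
    # Old-style helper: timers was the first optional parameter.
    return timers is not None

def recv_forward_new(tensor_shape=None, timers=None):
    # New-style helper: a hypothetical parameter now precedes timers,
    # so a legacy positional call binds timers to the wrong slot.
    return timers is not None

timers = object()  # stand-in for Megatron's timers object

# Legacy positional call against the old signature worked as intended.
assert recv_forward_old(timers) is True

# The same positional call against the new signature silently lands the
# timers object in tensor_shape; timing is dropped with no error raised.
assert recv_forward_new(timers) is False

# The keyword form keeps each argument bound to the parameter it names,
# no matter how the positional order evolves.
assert recv_forward_new(timers=timers) is True

The diff below applies the keyword form at every p2p_communication call site in both pipelining schedules.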
@@ -231,7 +231,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
     input_tensors[0].append(
-        p2p_communication.recv_forward(timers))
+        p2p_communication.recv_forward(timers=timers))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)
@@ -265,7 +265,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         else:
             input_tensor = \
                 p2p_communication.send_forward_recv_forward(
-                    output_tensor, recv_prev, timers)
+                    output_tensor, recv_prev=recv_prev, timers=timers)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)
 
     # Run 1F1B in steady state.
@@ -343,7 +343,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
-                p2p_communication.recv_backward(timers))
+                p2p_communication.recv_backward(timers=timers))
         for k in range(num_microbatches_remaining, num_microbatches):
             input_tensor_grad = backward_step_helper(k)
             next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -355,7 +355,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
                 recv_next = False
             output_tensor_grads[next_backward_model_chunk_id].append(
                 p2p_communication.send_backward_recv_backward(
-                    input_tensor_grad, recv_next, timers))
+                    input_tensor_grad, recv_next=recv_next, timers=timers))
 
     return losses_reduced
@@ -389,10 +389,10 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
-        input_tensor = p2p_communication.recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers=timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
-        p2p_communication.send_forward(output_tensor, timers)
+        p2p_communication.send_forward(output_tensor, timers=timers)
 
         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
@@ -401,7 +401,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
     # If all microbatches are run in warmup / cooldown phase, then no need to
     # receive this tensor here.
     if num_microbatches_remaining > 0:
-        input_tensor = p2p_communication.recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers=timers)
 
     # Run 1F1B in steady state.
     for i in range(num_microbatches_remaining):
@@ -410,11 +410,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         if forward_only:
-            p2p_communication.send_forward(output_tensor, timers)
+            p2p_communication.send_forward(output_tensor, timers=timers)
         else:
             output_tensor_grad = \
                 p2p_communication.send_forward_recv_backward(output_tensor,
-                                                             timers)
+                                                             timers=timers)
 
         # Add input_tensor and output_tensor to end of list, then pop from the
         # start of the list for backward pass.
@@ -423,7 +423,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
         if forward_only:
             if not last_iteration:
-                input_tensor = p2p_communication.recv_forward(timers)
+                input_tensor = p2p_communication.recv_forward(timers=timers)
         else:
             input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
@@ -433,11 +433,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
             if last_iteration:
                 input_tensor = None
-                p2p_communication.send_backward(input_tensor_grad, timers)
+                p2p_communication.send_backward(input_tensor_grad, timers=timers)
             else:
                 input_tensor = \
                     p2p_communication.send_backward_recv_forward(
-                        input_tensor_grad, timers)
+                        input_tensor_grad, timers=timers)
 
     # Run cooldown backward passes.
     if not forward_only:
@@ -445,12 +445,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
             input_tensor = input_tensors.pop(0)
            output_tensor = output_tensors.pop(0)
 
-            output_tensor_grad = p2p_communication.recv_backward(timers)
+            output_tensor_grad = p2p_communication.recv_backward(timers=timers)
 
             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
                               output_tensor_grad)
 
-            p2p_communication.send_backward(input_tensor_grad, timers)
+            p2p_communication.send_backward(input_tensor_grad, timers=timers)
 
     return losses_reduced
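
As a general Python hardening note, and not something this merge request itself does: declaring options like timers keyword-only (PEP 3102, a bare * in the signature) turns this whole class of positional misuse into an immediate TypeError instead of a silent misbinding. A sketch with a hypothetical signature:

# General defensive pattern, not part of this merge request: parameters
# declared after a bare * can only be passed by keyword.

def recv_forward(*, tensor_shape=None, timers=None):
    # Hypothetical signature for illustration only.
    return timers is not None

print(recv_forward(timers=object()))   # True: bound to the intended parameter

try:
    recv_forward(object())             # positional use is now a hard error
except TypeError as exc:
    print(exc)  # e.g. "recv_forward() takes 0 positional arguments but 1 was given"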