Commit 6a680986 authored by Jared Casper's avatar Jared Casper
Browse files

Merge branch 'bugfix' into 'main'

Use timers kwargs correctly to prevent bug with new p2p_communication API

See merge request ADLR/megatron-lm!295
parents a676bc2d e727de99
......@@ -231,7 +231,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
# Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(
p2p_communication.recv_forward(timers))
p2p_communication.recv_forward(timers=timers))
for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k)
......@@ -265,7 +265,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
else:
input_tensor = \
p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev, timers)
output_tensor, recv_prev=recv_prev, timers=timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
......@@ -343,7 +343,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
p2p_communication.recv_backward(timers))
p2p_communication.recv_backward(timers=timers))
for k in range(num_microbatches_remaining, num_microbatches):
input_tensor_grad = backward_step_helper(k)
next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
......@@ -355,7 +355,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next, timers))
input_tensor_grad, recv_next=recv_next, timers=timers))
return losses_reduced
......@@ -389,10 +389,10 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
input_tensor = p2p_communication.recv_forward(timers)
input_tensor = p2p_communication.recv_forward(timers=timers)
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
p2p_communication.send_forward(output_tensor, timers)
p2p_communication.send_forward(output_tensor, timers=timers)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
......@@ -401,7 +401,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_remaining > 0:
input_tensor = p2p_communication.recv_forward(timers)
input_tensor = p2p_communication.recv_forward(timers=timers)
# Run 1F1B in steady state.
for i in range(num_microbatches_remaining):
......@@ -410,11 +410,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
output_tensor = forward_step(forward_step_func, data_iterator, model,
input_tensor, losses_reduced)
if forward_only:
p2p_communication.send_forward(output_tensor, timers)
p2p_communication.send_forward(output_tensor, timers=timers)
else:
output_tensor_grad = \
p2p_communication.send_forward_recv_backward(output_tensor,
timers)
timers=timers)
# Add input_tensor and output_tensor to end of list, then pop from the
# start of the list for backward pass.
......@@ -423,7 +423,7 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if forward_only:
if not last_iteration:
input_tensor = p2p_communication.recv_forward(timers)
input_tensor = p2p_communication.recv_forward(timers=timers)
else:
input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
......@@ -433,11 +433,11 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
if last_iteration:
input_tensor = None
p2p_communication.send_backward(input_tensor_grad, timers)
p2p_communication.send_backward(input_tensor_grad, timers=timers)
else:
input_tensor = \
p2p_communication.send_backward_recv_forward(
input_tensor_grad, timers)
input_tensor_grad, timers=timers)
# Run cooldown backward passes.
if not forward_only:
......@@ -445,12 +445,12 @@ def forward_backward_pipelining_without_interleaving(forward_step_func, data_ite
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = p2p_communication.recv_backward(timers)
output_tensor_grad = p2p_communication.recv_backward(timers=timers)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor,
output_tensor_grad)
p2p_communication.send_backward(input_tensor_grad, timers)
p2p_communication.send_backward(input_tensor_grad, timers=timers)
return losses_reduced
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment