Commit 788b59e7 authored by Jared Casper

Merge branch 'remove_ring_exchange' into 'main'

Get PyTorch batched communication API working for interleaved schedule

See merge request ADLR/megatron-lm!242
parents 1acac4e3 cbad126a
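
The diff below routes all pipeline point-to-point traffic through PyTorch's batched communication API (torch.distributed.P2POp plus torch.distributed.batch_isend_irecv) and drops the use_ring_exchange code path, guarding the batch launch so it is skipped when there are no ops. For orientation, here is a minimal sketch of that pattern; the function name, its arguments, and the CPU-safe guard around the final synchronize are illustrative assumptions, not code from this merge request.

import torch
import torch.distributed as dist

def batched_neighbor_exchange(tensor_send_prev, tensor_recv_prev,
                              tensor_send_next, tensor_recv_next,
                              prev_rank, next_rank):
    """Post sends/recvs to both pipeline neighbors as one batch, then wait."""
    ops = []
    if tensor_send_prev is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_prev, prev_rank))
    if tensor_recv_prev is not None:   # pre-allocated receive buffer
        ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
    if tensor_send_next is not None:
        ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank))
    if tensor_recv_next is not None:   # pre-allocated receive buffer
        ops.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))
    if len(ops) > 0:   # only launch a batch when there is work, as in the MR
        reqs = dist.batch_isend_irecv(ops)
        for req in reqs:
            req.wait()
    # The MR synchronizes unconditionally to guard against a race when the
    # received tensors are consumed; guarded here so the sketch also runs on CPU.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
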
@@ -104,9 +104,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
             torch.distributed.irecv, tensor_recv_next,
             mpu.get_pipeline_model_parallel_next_rank())
         ops.append(recv_next_op)
-    reqs = torch.distributed.batch_isend_irecv(ops)
-    for req in reqs:
-        req.wait()
+    if len(ops) > 0:
+        reqs = torch.distributed.batch_isend_irecv(ops)
+        for req in reqs:
+            req.wait()
     # To protect against race condition when using batch_isend_irecv().
     torch.cuda.synchronize()
 
@@ -123,7 +124,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(timers=None, use_ring_exchange=False):
+def recv_forward(timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
@@ -134,14 +135,13 @@ def recv_forward(timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=True,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('forward-recv').stop()
     return input_tensor
 
 
-def recv_backward(timers=None, use_ring_exchange=False):
+def recv_backward(timers=None):
     """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -152,14 +152,13 @@ def recv_backward(timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=True)
         if timers is not None:
             timers('backward-recv').stop()
     return output_tensor_grad
 
 
-def send_forward(output_tensor, timers=None, use_ring_exchange=False):
+def send_forward(output_tensor, timers=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
@@ -168,13 +167,12 @@ def send_forward(output_tensor, timers=None, use_ring_exchange=False):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('forward-send').stop()
 
 
-def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
+def send_backward(input_tensor_grad, timers=None):
     """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
@@ -183,13 +181,12 @@ def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=False,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('backward-send').stop()
 
 
-def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
+def send_forward_recv_backward(output_tensor, timers=None):
     """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -200,14 +197,13 @@ def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=Fal
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=True)
         if timers is not None:
             timers('forward-send-backward-recv').stop()
     return output_tensor_grad
 
 
-def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
+def send_backward_recv_forward(input_tensor_grad, timers=None):
     """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
@@ -218,8 +214,7 @@ def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange
             tensor_send_next=None,
             tensor_send_prev=input_tensor_grad,
             recv_prev=True,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('backward-send-forward-recv').stop()
     return input_tensor
@@ -233,8 +228,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
         tensor_send_next=output_tensor,
         tensor_send_prev=None,
         recv_prev=recv_prev,
-        recv_next=False,
-        use_ring_exchange=True)
+        recv_next=False)
     if timers is not None:
         timers('forward-send-forward-recv').stop()
     return input_tensor
@@ -248,8 +242,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
         tensor_send_next=None,
         tensor_send_prev=input_tensor_grad,
         recv_prev=False,
-        recv_next=recv_next,
-        use_ring_exchange=True)
+        recv_next=recv_next)
     if timers is not None:
         timers('backward-send-backward-recv').stop()
     return output_tensor_grad
@@ -265,8 +258,7 @@ def send_forward_backward_recv_forward_backward(
         tensor_send_next=output_tensor,
         tensor_send_prev=input_tensor_grad,
         recv_prev=recv_prev,
-        recv_next=recv_next,
-        use_ring_exchange=True)
+        recv_next=recv_next)
     if timers is not None:
         timers('forward-backward-send-forward-backward-recv').stop()
     return input_tensor, output_tensor_grad
@@ -210,7 +210,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
     input_tensors[0].append(
-        p2p_communication.recv_forward(timers, use_ring_exchange=True))
+        p2p_communication.recv_forward(timers))
 
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)
@@ -322,7 +322,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
-                p2p_communication.recv_backward(timers, use_ring_exchange=True))
+                p2p_communication.recv_backward(timers))
         for k in range(num_microbatches_remaining, num_microbatches):
             input_tensor_grad = backward_step_helper(k)
             next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
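
To experiment with the same batched API outside Megatron-LM, a small self-contained demo along the following lines should work. Everything in it is an assumption for illustration (two CPU processes, the gloo backend's point-to-point support, a hard-coded rendezvous address); it is not part of this merge request.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    peer = (rank + 1) % world_size            # the only other rank in this demo
    send_buf = torch.full((4,), float(rank))  # payload identifies the sender
    recv_buf = torch.empty(4)                 # pre-allocated receive buffer

    ops = [dist.P2POp(dist.isend, send_buf, peer),
           dist.P2POp(dist.irecv, recv_buf, peer)]
    if len(ops) > 0:                          # mirrors the guard added in the MR
        for req in dist.batch_isend_irecv(ops):
            req.wait()

    print(f"rank {rank} received {recv_buf.tolist()} from rank {peer}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)

Each process should print a vector filled with the other rank's index, confirming that both directions of the batched exchange completed.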