Commit 78cf869f authored by Deepak Narayanan

Get PyTorch batched communication API working for interleaved schedule

parent 3cbf7547
megatron/p2p_communication.py

@@ -21,8 +21,7 @@ from megatron import get_args
 from megatron import mpu
 
 
-def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 use_ring_exchange=False):
+def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
 
@@ -35,8 +34,6 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
             previous rank.
         recv_next: boolean for whether tensor should be received from
            next rank.
-        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
-            API should be used.
 
     Returns:
         (tensor_recv_prev, tensor_recv_next)
@@ -76,34 +73,28 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
 
     # Send tensors in both the forward and backward directions as appropriate.
-    if use_ring_exchange:
-        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
-                                        tensor_recv_prev=tensor_recv_prev,
-                                        tensor_send_next=tensor_send_next,
-                                        tensor_recv_next=tensor_recv_next,
-                                        group=mpu.get_pipeline_model_parallel_group())
-    else:
-        ops = []
-        if tensor_send_prev is not None:
-            send_prev_op = torch.distributed.P2POp(
-                torch.distributed.isend, tensor_send_prev,
-                mpu.get_pipeline_model_parallel_prev_rank())
-            ops.append(send_prev_op)
-        if tensor_recv_prev is not None:
-            recv_prev_op = torch.distributed.P2POp(
-                torch.distributed.irecv, tensor_recv_prev,
-                mpu.get_pipeline_model_parallel_prev_rank())
-            ops.append(recv_prev_op)
-        if tensor_send_next is not None:
-            send_next_op = torch.distributed.P2POp(
-                torch.distributed.isend, tensor_send_next,
-                mpu.get_pipeline_model_parallel_next_rank())
-            ops.append(send_next_op)
-        if tensor_recv_next is not None:
-            recv_next_op = torch.distributed.P2POp(
-                torch.distributed.irecv, tensor_recv_next,
-                mpu.get_pipeline_model_parallel_next_rank())
-            ops.append(recv_next_op)
+    ops = []
+    if tensor_send_prev is not None:
+        send_prev_op = torch.distributed.P2POp(
+            torch.distributed.isend, tensor_send_prev,
+            mpu.get_pipeline_model_parallel_prev_rank())
+        ops.append(send_prev_op)
+    if tensor_recv_prev is not None:
+        recv_prev_op = torch.distributed.P2POp(
+            torch.distributed.irecv, tensor_recv_prev,
+            mpu.get_pipeline_model_parallel_prev_rank())
+        ops.append(recv_prev_op)
+    if tensor_send_next is not None:
+        send_next_op = torch.distributed.P2POp(
+            torch.distributed.isend, tensor_send_next,
+            mpu.get_pipeline_model_parallel_next_rank())
+        ops.append(send_next_op)
+    if tensor_recv_next is not None:
+        recv_next_op = torch.distributed.P2POp(
+            torch.distributed.irecv, tensor_recv_next,
+            mpu.get_pipeline_model_parallel_next_rank())
+        ops.append(recv_next_op)
+    if len(ops) > 0:
         reqs = torch.distributed.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()
@@ -123,7 +114,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next
 
 
-def recv_forward(timers=None, use_ring_exchange=False):
+def recv_forward(timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
@@ -134,14 +125,13 @@ def recv_forward(timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=True,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('forward-recv').stop()
     return input_tensor
 
 
-def recv_backward(timers=None, use_ring_exchange=False):
+def recv_backward(timers=None):
     """Receive tensor from next rank in pipeline (backward receive)."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -152,14 +142,13 @@ def recv_backward(timers=None, use_ring_exchange=False):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=True)
         if timers is not None:
             timers('backward-recv').stop()
     return output_tensor_grad
 
 
-def send_forward(output_tensor, timers=None, use_ring_exchange=False):
+def send_forward(output_tensor, timers=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
@@ -168,13 +157,12 @@ def send_forward(output_tensor, timers=None, use_ring_exchange=False):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=False,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=False)
         if timers is not None:
             timers('forward-send').stop()
 
 
-def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
+def send_backward(input_tensor_grad, timers=None):
     """Send tensor to previous rank in pipeline (backward send)."""
     if not mpu.is_pipeline_first_stage():
         if timers is not None:
@@ -183,13 +171,12 @@ def send_backward(input_tensor_grad, timers=None, use_ring_exchange=False):
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
-           recv_next=False,
-           use_ring_exchange=use_ring_exchange)
+           recv_next=False)
         if timers is not None:
             timers('backward-send').stop()
 
 
-def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
+def send_forward_recv_backward(output_tensor, timers=None):
     """Batched send and recv with next rank in pipeline."""
     if mpu.is_pipeline_last_stage():
         output_tensor_grad = None
@@ -200,14 +187,13 @@ def send_forward_recv_backward(output_tensor, timers=None, use_ring_exchange=False):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=True,
-            use_ring_exchange=use_ring_exchange)
+            recv_next=True)
         if timers is not None:
             timers('forward-send-backward-recv').stop()
     return output_tensor_grad
 
 
-def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
+def send_backward_recv_forward(input_tensor_grad, timers=None):
     """Batched send and recv with previous rank in pipeline."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
@@ -218,8 +204,7 @@ def send_backward_recv_forward(input_tensor_grad, timers=None, use_ring_exchange=False):
             tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
-           recv_next=False,
-           use_ring_exchange=use_ring_exchange)
+           recv_next=False)
         if timers is not None:
             timers('backward-send-forward-recv').stop()
     return input_tensor
@@ -233,8 +218,7 @@ def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=recv_prev,
-           recv_next=False,
-           use_ring_exchange=True)
+           recv_next=False)
         if timers is not None:
             timers('forward-send-forward-recv').stop()
     return input_tensor
@@ -248,8 +232,7 @@ def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
-           recv_next=recv_next,
-           use_ring_exchange=True)
+           recv_next=recv_next)
         if timers is not None:
             timers('backward-send-backward-recv').stop()
     return output_tensor_grad
@@ -265,8 +248,7 @@ def send_forward_backward_recv_forward_backward(
            tensor_send_next=output_tensor,
            tensor_send_prev=input_tensor_grad,
            recv_prev=recv_prev,
-           recv_next=recv_next,
-           use_ring_exchange=True)
+           recv_next=recv_next)
         if timers is not None:
             timers('forward-backward-send-forward-backward-recv').stop()
     return input_tensor, output_tensor_grad
megatron/schedules.py

@@ -210,7 +210,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
     input_tensors[0].append(
-        p2p_communication.recv_forward(timers, use_ring_exchange=True))
+        p2p_communication.recv_forward(timers))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)
 
@@ -322,7 +322,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator,
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
-                p2p_communication.recv_backward(timers, use_ring_exchange=True))
+                p2p_communication.recv_backward(timers))
         for k in range(num_microbatches_remaining, num_microbatches):
            input_tensor_grad = backward_step_helper(k)
            next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
...
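
For reference, the pattern this commit standardizes on is torch.distributed.P2POp combined with torch.distributed.batch_isend_irecv: build a list of send/receive ops, submit them as one batch, then wait on the returned requests. Below is a minimal standalone sketch of that pattern, separate from Megatron's helpers; the tensor shape, the ring topology, and the Gloo backend are illustrative assumptions for a CPU-only demo (Megatron itself runs this over NCCL between pipeline stages).

import torch
import torch.distributed as dist


def main():
    # Reads MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE set by the launcher.
    dist.init_process_group(backend="gloo")  # assumption: CPU demo; NCCL on GPUs
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    next_rank = (rank + 1) % world_size
    prev_rank = (rank - 1) % world_size

    # Each rank sends its rank id "forward" and receives from the rank behind it.
    tensor_send_next = torch.full((4,), float(rank))
    tensor_recv_prev = torch.empty(4)

    # Build the op list first, then hand the whole batch to the backend at once,
    # mirroring the ops / batch_isend_irecv structure in _communicate above.
    ops = [
        dist.P2POp(dist.isend, tensor_send_next, next_rank),
        dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank),
    ]
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()

    print(f"rank {rank}: received {tensor_recv_prev.tolist()} from rank {prev_rank}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

Run with something like torchrun --nproc_per_node=2 on a single machine; each rank should print the values sent by its predecessor, confirming that the batched send and receive completed together.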