Commit cbad126a authored by Deepak Narayanan

Bring back call to ring_exchange() in _communicate()

parent 78cf869f
@@ -21,7 +21,8 @@ from megatron import get_args
 from megatron import mpu
 
 
-def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
+def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
+                 use_ring_exchange=False):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
 
@@ -34,6 +35,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
                    previous rank.
         recv_next: boolean for whether tensor should be received from
                    next rank.
+        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
+                           API should be used.
 
     Returns:
         (tensor_recv_prev, tensor_recv_next)
@@ -73,31 +76,38 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next):
         tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
 
     # Send tensors in both the forward and backward directions as appropriate.
-    ops = []
-    if tensor_send_prev is not None:
-        send_prev_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_prev,
-            mpu.get_pipeline_model_parallel_prev_rank())
-        ops.append(send_prev_op)
-    if tensor_recv_prev is not None:
-        recv_prev_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_prev,
-            mpu.get_pipeline_model_parallel_prev_rank())
-        ops.append(recv_prev_op)
-    if tensor_send_next is not None:
-        send_next_op = torch.distributed.P2POp(
-            torch.distributed.isend, tensor_send_next,
-            mpu.get_pipeline_model_parallel_next_rank())
-        ops.append(send_next_op)
-    if tensor_recv_next is not None:
-        recv_next_op = torch.distributed.P2POp(
-            torch.distributed.irecv, tensor_recv_next,
-            mpu.get_pipeline_model_parallel_next_rank())
-        ops.append(recv_next_op)
-    if len(ops) > 0:
-        reqs = torch.distributed.batch_isend_irecv(ops)
-        for req in reqs:
-            req.wait()
+    if use_ring_exchange:
+        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
+                                        tensor_recv_prev=tensor_recv_prev,
+                                        tensor_send_next=tensor_send_next,
+                                        tensor_recv_next=tensor_recv_next,
+                                        group=mpu.get_pipeline_model_parallel_group())
+    else:
+        ops = []
+        if tensor_send_prev is not None:
+            send_prev_op = torch.distributed.P2POp(
+                torch.distributed.isend, tensor_send_prev,
+                mpu.get_pipeline_model_parallel_prev_rank())
+            ops.append(send_prev_op)
+        if tensor_recv_prev is not None:
+            recv_prev_op = torch.distributed.P2POp(
+                torch.distributed.irecv, tensor_recv_prev,
+                mpu.get_pipeline_model_parallel_prev_rank())
+            ops.append(recv_prev_op)
+        if tensor_send_next is not None:
+            send_next_op = torch.distributed.P2POp(
+                torch.distributed.isend, tensor_send_next,
+                mpu.get_pipeline_model_parallel_next_rank())
+            ops.append(send_next_op)
+        if tensor_recv_next is not None:
+            recv_next_op = torch.distributed.P2POp(
+                torch.distributed.irecv, tensor_recv_next,
+                mpu.get_pipeline_model_parallel_next_rank())
+            ops.append(recv_next_op)
+        if len(ops) > 0:
+            reqs = torch.distributed.batch_isend_irecv(ops)
+            for req in reqs:
+                req.wait()
     # To protect against race condition when using batch_isend_irecv().
     torch.cuda.synchronize()
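For reference, a minimal sketch of how a caller could opt into the restored ring-exchange path after this change. The real call sites live in megatron/schedules.py and are not part of this diff, so the tensor names and recv flags below are illustrative assumptions, and torch.distributed.ring_exchange() is only available in PyTorch builds that ship that API.

# Hypothetical caller sketch (not part of this commit): exchange activations
# with both pipeline neighbours via the ring_exchange() path.
tensor_recv_prev, tensor_recv_next = _communicate(
    tensor_send_next=output_tensor,      # illustrative names; actual call
    tensor_send_prev=input_tensor_grad,  # sites are in megatron/schedules.py
    recv_prev=True,
    recv_next=True,
    use_ring_exchange=True)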