Commit d899988e authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Use batched send and recv instead of torch.distributed.ring_exchange()

parent 2348c99a
......@@ -64,10 +64,6 @@ def parse_args(extra_args_provider=None, defaults={},
args.pipeline_model_parallel_size = min(
args.pipeline_model_parallel_size,
(args.world_size // args.tensor_model_parallel_size))
if args.pipeline_model_parallel_size > 1:
if "ring_exchange" not in dir(torch.distributed):
raise Exception('PyTorch with torch.distributed.ring_exchange '
'needed to run pipeline MP!')
# Checks.
model_parallel_size = args.pipeline_model_parallel_size * \
args.tensor_model_parallel_size
......
......@@ -36,6 +36,8 @@ from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank
from .initialize import get_pipeline_model_parallel_next_rank
from .initialize import get_pipeline_model_parallel_prev_rank
from .initialize import get_tensor_model_parallel_world_size, set_tensor_model_parallel_world_size
from .initialize import get_pipeline_model_parallel_world_size, set_pipeline_model_parallel_world_size
from .initialize import initialize_model_parallel
......
......@@ -276,16 +276,30 @@ def get_tensor_model_parallel_src_rank():
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
def get_pipeline_model_parallel_first_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_first_rank():
def get_pipeline_model_parallel_next_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
def get_pipeline_model_parallel_prev_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, \
"Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
......
......@@ -325,7 +325,7 @@ def setup_model_and_optimizer(model_provider_func):
def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward):
"""Communicate tensors between stages using torch.distributed.ring_exchange(.) API."""
"""Communicate tensors between stages."""
args = get_args()
# Create placeholder tensors for receive in forward and backward directions
......@@ -348,11 +348,26 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
dtype=dtype)
# Send tensors in both the forward and backward directions as appropriate.
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
tensor_recv_prev=tensor_recv_prev,
tensor_send_next=tensor_send_next,
tensor_recv_next=tensor_recv_next,
group=mpu.get_pipeline_model_parallel_group())
ops = []
if tensor_send_prev is not None:
send_prev_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_prev,
mpu.get_pipeline_model_parallel_prev_rank())
ops.append(recv_prev_op)
if tensor_send_next is not None:
send_next_op = torch.distributed.P2POp(torch.distributed.isend, tensor_send_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(send_next_op)
if tensor_recv_next is not None:
recv_next_op = torch.distributed.P2POp(torch.distributed.irecv, tensor_recv_next,
mpu.get_pipeline_model_parallel_next_rank())
ops.append(recv_next_op)
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
return tensor_recv_prev, tensor_recv_next
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment