Commit 8df49e72 authored by Jared Casper

Merge branch 'add_ring_exchange_flag' into 'main'

Added a flag to switch between the PyTorch and ring-exchange p2p communication paths.

See merge request ADLR/megatron-lm!434
parents c464a103 e36cdd79
megatron/arguments.py
@@ -761,6 +761,10 @@ def _add_distributed_args(parser):
     group.add_argument('--no-scatter-gather-tensors-in-pipeline', action='store_false',
                        help='Use scatter/gather to optimize communication of tensors in pipeline',
                        dest='scatter_gather_tensors_in_pipeline')
+    group.add_argument('--use-ring-exchange-p2p', action='store_true',
+                       default=False, help='If set, use custom-built ring exchange '
+                       'for p2p communications. Note that this option will require '
+                       'a custom built image that support ring-exchange p2p.')
     group.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher.')
     group.add_argument('--lazy-mpu-init', type=bool, required=False,
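For reference, the new option is a standard argparse store_true flag: it defaults to False and becomes True only when --use-ring-exchange-p2p appears on the command line. Below is a minimal, self-contained sketch of that behavior; the parser and argument group here are hypothetical stand-ins, not Megatron's actual _add_distributed_args plumbing.

# Minimal sketch using only the standard library; the parser and group
# names are illustrative and are not Megatron's actual objects.
import argparse

parser = argparse.ArgumentParser(description='flag behavior demo')
group = parser.add_argument_group(title='distributed')
group.add_argument('--use-ring-exchange-p2p', action='store_true',
                   default=False,
                   help='If set, use custom-built ring exchange for p2p '
                        'communications (requires a custom-built image).')

print(parser.parse_args([]).use_ring_exchange_p2p)                           # False
print(parser.parse_args(['--use-ring-exchange-p2p']).use_ring_exchange_p2p)  # True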
megatron/p2p_communication.py
@@ -23,7 +23,6 @@ from megatron import mpu
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                  tensor_shape,
-                 use_ring_exchange=False,
                  dtype_=None):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
@@ -40,8 +39,6 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         tensor_shape: shape of tensor to receive (this method assumes that all
                       tensors sent and received in a single function call are
                       the same shape).
-        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
-                           API should be used.
         dtype_: optional, this is used when the tensor that needs to be
                 communicated is different from args.params_dtype.
     Returns:
@@ -103,7 +100,7 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
     # Send tensors in both the forward and backward directions as appropriate.
-    if use_ring_exchange:
+    if args.use_ring_exchange_p2p:
         torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                         tensor_recv_prev=tensor_recv_prev,
                                         tensor_send_next=tensor_send_next,
@@ -135,8 +132,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
             reqs = torch.distributed.batch_isend_irecv(ops)
             for req in reqs:
                 req.wait()
     # To protect against race condition when using batch_isend_irecv().
     torch.cuda.synchronize()
     # If using scatter-gather optimization, gather smaller chunks.
     if not override_scatter_gather_tensors_in_pipeline and \
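Taken together, the change makes _communicate read the backend choice from the parsed arguments rather than a per-call keyword: when --use-ring-exchange-p2p is set, the custom torch.distributed.ring_exchange path is taken, otherwise the stock batch_isend_irecv path runs. The condensed sketch below is illustrative only, not Megatron's full _communicate (tensor allocation, the scatter/gather optimization, and rank/group lookup are omitted), and torch.distributed.ring_exchange exists only in the custom-built PyTorch image that the flag's help text refers to.

# Illustrative sketch of the dispatch added in this merge request; simplified
# from megatron/p2p_communication.py and not a drop-in replacement for it.
import torch
import torch.distributed as dist

def exchange(tensor_send_prev, tensor_recv_prev, tensor_send_next, tensor_recv_next,
             prev_rank, next_rank, use_ring_exchange_p2p, group=None):
    if use_ring_exchange_p2p:
        # Only available in a custom-built PyTorch that provides ring_exchange.
        dist.ring_exchange(tensor_send_prev=tensor_send_prev,
                           tensor_recv_prev=tensor_recv_prev,
                           tensor_send_next=tensor_send_next,
                           tensor_recv_next=tensor_recv_next,
                           group=group)
    else:
        # Standard PyTorch path: batch the point-to-point ops and wait on them.
        ops = []
        if tensor_send_prev is not None:
            ops.append(dist.P2POp(dist.isend, tensor_send_prev, prev_rank))
        if tensor_recv_prev is not None:
            ops.append(dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank))
        if tensor_send_next is not None:
            ops.append(dist.P2POp(dist.isend, tensor_send_next, next_rank))
        if tensor_recv_next is not None:
            ops.append(dist.P2POp(dist.irecv, tensor_recv_next, next_rank))
        if len(ops) > 0:
            reqs = dist.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
    # To protect against a race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

Because the switch now comes from the parsed arguments, the use_ring_exchange keyword and its docstring entry can be dropped from _communicate, as the hunks above show.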