Commit 1dccefd8 authored by Mostofa Patwary's avatar Mostofa Patwary Committed by Deepak Narayanan
Browse files

Make it possible to pass in tensor shapes to communication methods in p2p_communication.py

parent 3db6517a
......@@ -22,7 +22,9 @@ from megatron import mpu
def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
use_ring_exchange=False):
use_ring_exchange=False, tensor_shape=None,
override_scatter_gather_tensors_in_pipeline=False,
dtype_=None):
"""Communicate tensors between stages. Used as helper method in other
communication methods that are used in megatron/schedules.py.
......@@ -37,7 +39,14 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
next rank.
use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
API should be used.
tensor_shape: optional, use when the input sequence contains less
tokens than the default sequence length
override_scatter_gather_tensors_in_pipeline: optional, this is used
when tensor_shape is
provided to overwide
scatter gather tensors
dtype_: optional, this is used when tensor_shape is provied and what
is the type of tensor_shape
Returns:
(tensor_recv_prev, tensor_recv_next)
"""
......@@ -47,8 +56,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if args.scatter_gather_tensors_in_pipeline:
if tensor_shape is None:
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
mpu.get_tensor_model_parallel_world_size()
else:
......@@ -56,19 +67,26 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
requires_grad = True
if dtype_ is not None:
dtype = dtype_
requires_grad = False
if recv_prev:
tensor_recv_prev = torch.empty(tensor_chunk_shape,
requires_grad=True,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_next:
tensor_recv_next = torch.empty(tensor_chunk_shape,
requires_grad=True,
requires_grad=requires_grad,
device=torch.cuda.current_device(),
dtype=dtype)
# Split tensor into smaller chunks if using scatter-gather optimization.
if args.scatter_gather_tensors_in_pipeline:
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline:
if tensor_send_next is not None:
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
......@@ -112,7 +130,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if args.scatter_gather_tensors_in_pipeline:
if not override_scatter_gather_tensors_in_pipeline and \
args.scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_()
......@@ -124,8 +143,11 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
return tensor_recv_prev, tensor_recv_next
def recv_forward(timers=None):
def recv_forward(tensor_shape=None,
override_scatter_gather_tensors_in_pipeline=False,
dtype_=None, timers=None):
"""Receive tensor from previous rank in pipeline (forward receive)."""
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
......@@ -135,7 +157,11 @@ def recv_forward(timers=None):
tensor_send_next=None,
tensor_send_prev=None,
recv_prev=True,
recv_next=False)
recv_next=False,
tensor_shape=tensor_shape,
override_scatter_gather_tensors_in_pipeline=\
override_scatter_gather_tensors_in_pipeline,
dtype_=dtype_)
if timers is not None:
timers('forward-recv').stop()
return input_tensor
......@@ -158,8 +184,11 @@ def recv_backward(timers=None):
return output_tensor_grad
def send_forward(output_tensor, timers=None):
def send_forward(output_tensor, timers=None,
override_scatter_gather_tensors_in_pipeline=False,
dtype_=None):
"""Send tensor to next rank in pipeline (forward send)."""
if not mpu.is_pipeline_last_stage():
if timers is not None:
timers('forward-send').start()
......@@ -167,7 +196,10 @@ def send_forward(output_tensor, timers=None):
tensor_send_next=output_tensor,
tensor_send_prev=None,
recv_prev=False,
recv_next=False)
recv_next=False,
override_scatter_gather_tensors_in_pipeline=\
override_scatter_gather_tensors_in_pipeline,
dtype_=dtype_)
if timers is not None:
timers('forward-send').stop()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment