"vscode:/vscode.git/clone" did not exist on "2a8cc8d1138837309606ebd2ed22008c80e9e1fb"
Commit 1dccefd8 authored by Mostofa Patwary, committed by Deepak Narayanan

Make it possible to pass in tensor shapes to communication methods in p2p_communication.py

parent 3db6517a
@@ -22,7 +22,9 @@ from megatron import mpu
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
-                 use_ring_exchange=False):
+                 use_ring_exchange=False, tensor_shape=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None):
     """Communicate tensors between stages. Used as helper method in other
     communication methods that are used in megatron/schedules.py.
@@ -37,7 +39,14 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                    next rank.
         use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                            API should be used.
+        tensor_shape: optional, use when the input sequence contains fewer
+                      tokens than the default sequence length
+        override_scatter_gather_tensors_in_pipeline: optional, used when
+                      tensor_shape is provided to bypass the scatter-gather
+                      optimization for pipeline tensors
+        dtype_: optional, used when tensor_shape is provided; specifies the
+                dtype of the tensors to be received
     Returns:
         (tensor_recv_prev, tensor_recv_next)
     """
@@ -47,8 +56,10 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # if needed.
     tensor_recv_prev = None
     tensor_recv_next = None
-    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
-    if args.scatter_gather_tensors_in_pipeline:
+    if tensor_shape is None:
+        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
             mpu.get_tensor_model_parallel_world_size()
     else:
@@ -56,19 +67,26 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     dtype = args.params_dtype
     if args.fp32_residual_connection:
         dtype = torch.float
+
+    requires_grad = True
+    if dtype_ is not None:
+        dtype = dtype_
+        requires_grad = False
+
     if recv_prev:
         tensor_recv_prev = torch.empty(tensor_chunk_shape,
-                                       requires_grad=True,
+                                       requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
     if recv_next:
         tensor_recv_next = torch.empty(tensor_chunk_shape,
-                                       requires_grad=True,
+                                       requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)

     # Split tensor into smaller chunks if using scatter-gather optimization.
-    if args.scatter_gather_tensors_in_pipeline:
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         if tensor_send_next is not None:
             tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
@@ -112,7 +130,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     torch.cuda.synchronize()

     # If using scatter-gather optimization, gather smaller chunks.
-    if args.scatter_gather_tensors_in_pipeline:
+    if not override_scatter_gather_tensors_in_pipeline and \
+            args.scatter_gather_tensors_in_pipeline:
         if recv_prev:
             tensor_recv_prev = mpu.gather_split_1d_tensor(
                 tensor_recv_prev).view(tensor_shape).requires_grad_()
@@ -124,8 +143,11 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     return tensor_recv_prev, tensor_recv_next


-def recv_forward(timers=None):
+def recv_forward(tensor_shape=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None, timers=None):
     """Receive tensor from previous rank in pipeline (forward receive)."""
     if mpu.is_pipeline_first_stage():
         input_tensor = None
     else:
@@ -135,7 +157,11 @@ def recv_forward(timers=None):
             tensor_send_next=None,
             tensor_send_prev=None,
             recv_prev=True,
-            recv_next=False)
+            recv_next=False,
+            tensor_shape=tensor_shape,
+            override_scatter_gather_tensors_in_pipeline=\
+                override_scatter_gather_tensors_in_pipeline,
+            dtype_=dtype_)
         if timers is not None:
             timers('forward-recv').stop()
     return input_tensor
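With these arguments threaded through, a receiving stage can ask for a buffer whose shape and dtype differ from the training defaults. A hypothetical usage sketch: it assumes an already initialized Megatron pipeline-parallel run, and cur_seq_len, micro_batch_size and hidden_size are placeholder values, not names introduced by this commit:

```python
import torch
from megatron import p2p_communication

# Receive a shorter-than-default tensor from the previous stage, bypassing
# the scatter-gather optimization and forcing an fp16 receive buffer.
input_tensor = p2p_communication.recv_forward(
    tensor_shape=(cur_seq_len, micro_batch_size, hidden_size),  # placeholders
    override_scatter_gather_tensors_in_pipeline=True,
    dtype_=torch.float16)
```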
@@ -158,8 +184,11 @@ def recv_backward(timers=None):
     return output_tensor_grad


-def send_forward(output_tensor, timers=None):
+def send_forward(output_tensor, timers=None,
+                 override_scatter_gather_tensors_in_pipeline=False,
+                 dtype_=None):
     """Send tensor to next rank in pipeline (forward send)."""
     if not mpu.is_pipeline_last_stage():
         if timers is not None:
             timers('forward-send').start()
@@ -167,7 +196,10 @@ def send_forward(output_tensor, timers=None):
             tensor_send_next=output_tensor,
             tensor_send_prev=None,
             recv_prev=False,
-            recv_next=False)
+            recv_next=False,
+            override_scatter_gather_tensors_in_pipeline=\
+                override_scatter_gather_tensors_in_pipeline,
+            dtype_=dtype_)
         if timers is not None:
             timers('forward-send').stop()
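Note that send_forward gains the override and dtype_ arguments but not tensor_shape, since the shape of what is sent is carried by output_tensor itself; only the receiving side needs the shape to size its buffer. The matching send on the upstream stage, again as a hypothetical sketch under the same assumptions as above:

```python
# On the previous pipeline stage, send the corresponding tensor. The override
# and dtype_ values must agree with what the downstream recv_forward expects.
p2p_communication.send_forward(
    output_tensor,
    override_scatter_gather_tensors_in_pipeline=True,
    dtype_=torch.float16)
```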
...