Commit 1ad1e1b1 authored by Jared Casper

Merge branch 'nmt-main' into 'main'

Support for variable sequence lengths across micro-batches

See merge request ADLR/megatron-lm!472
parents b4297c6a c3e688d3
@@ -186,6 +186,13 @@ def validate_args(args, defaults={}):
     args.consumed_train_samples = 0
     args.consumed_valid_samples = 0
 
+    # Support for variable sequence lengths across batches/micro-batches.
+    # Set this if the dataloader supports generation of variable sequence
+    # lengths across batches/micro-batches. Because of the additional
+    # communication overhead it adds during pipeline parallelism, it should
+    # not be set if the sequence length is constant during training.
+    args.variable_seq_lengths = False
+
     # Iteration-based training.
     if args.train_iters:
         # If we use iteration-based training, make sure the
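Note that args.variable_seq_lengths is a plain attribute initialized in validate_args rather than a command-line flag, so a dataloader that emits variable-length micro-batches is expected to flip it at setup time. The snippet below is a minimal standalone sketch of the resulting buffer-allocation decision, not Megatron-LM code; SimpleNamespace stands in for the parsed args and the numbers are made up.

# Minimal sketch (not Megatron-LM code): with the flag off, a pipeline stage
# can preallocate its receive buffer from the static (seq_length,
# micro_batch_size, hidden_size); with the flag on, the shape has to come
# from a prior shape exchange such as _communicate_shapes below.
from types import SimpleNamespace

args = SimpleNamespace(seq_length=2048, micro_batch_size=4, hidden_size=1024,
                       variable_seq_lengths=False)

def recv_buffer_shape(args, communicated_shape=None):
    """Shape used to allocate a stage's receive buffer."""
    if not args.variable_seq_lengths:
        # Constant sequence length: the shape is known statically.
        return (args.seq_length, args.micro_batch_size, args.hidden_size)
    # Variable sequence length: the sender must tell us the shape first.
    assert communicated_shape is not None
    return tuple(communicated_shape)

print(recv_buffer_shape(args))                    # (2048, 4, 1024)
args.variable_seq_lengths = True
print(recv_buffer_shape(args, [1536, 4, 1024]))   # (1536, 4, 1024)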
@@ -883,7 +890,7 @@ def _add_data_args(parser):
                        help="Maximum decoder sequence length to process.")
     group.add_argument('--retriever-seq-length', type=int, default=256,
                        help='Maximum sequence length for the biencoder model '
-                       ' for retriever')
+                       'for retriever')
     group.add_argument('--sample-rate', type=float, default=1.0,
                        help='sample rate for training data. Supposed to be 0 '
                        ' < sample_rate < 1')
@@ -8,6 +8,96 @@ from megatron import get_args, core
 from megatron.core import mpu
 
 
+def _communicate_shapes(tensor_send_next, tensor_send_prev,
+                        recv_prev, recv_next):
+    """Communicate tensor shapes between stages. Used to communicate
+    tensor shapes before the actual tensor communication happens.
+    This is required when the sequence lengths across micro batches
+    are not uniform.
+
+    Takes the following arguments:
+        tensor_send_next: tensor to send to next rank (no tensor sent if
+                          set to None).
+        tensor_send_prev: tensor to send to prev rank (no tensor sent if
+                          set to None).
+        recv_prev: boolean for whether tensor should be received from
+                   previous rank.
+        recv_next: boolean for whether tensor should be received from
+                   next rank.
+    Returns:
+        (recv_prev_shape, recv_next_shape)
+    """
+    args = get_args()
+
+    recv_prev_shape_tensor = None
+    recv_next_shape_tensor = None
+    send_prev_shape_tensor = None
+    send_next_shape_tensor = None
+    if recv_prev:
+        recv_prev_shape_tensor = torch.empty((3),
+                                             device=torch.cuda.current_device(),
+                                             dtype=torch.int64)
+    if recv_next:
+        recv_next_shape_tensor = torch.empty((3),
+                                             device=torch.cuda.current_device(),
+                                             dtype=torch.int64)
+    if tensor_send_prev is not None:
+        send_prev_shape_tensor = torch.tensor(tensor_send_prev.size(),
+                                              device=torch.cuda.current_device(),
+                                              dtype=torch.int64)
+    if tensor_send_next is not None:
+        send_next_shape_tensor = torch.tensor(tensor_send_next.size(),
+                                              device=torch.cuda.current_device(),
+                                              dtype=torch.int64)
+
+    if args.use_ring_exchange_p2p:
+        torch.distributed.ring_exchange(tensor_send_prev=send_prev_shape_tensor,
+                                        tensor_recv_prev=recv_prev_shape_tensor,
+                                        tensor_send_next=send_next_shape_tensor,
+                                        tensor_recv_next=recv_next_shape_tensor,
+                                        group=mpu.get_pipeline_model_parallel_group())
+    else:
+        ops = []
+        if send_prev_shape_tensor is not None:
+            send_prev_op = torch.distributed.P2POp(
+                torch.distributed.isend, send_prev_shape_tensor,
+                mpu.get_pipeline_model_parallel_prev_rank())
+            ops.append(send_prev_op)
+        if recv_prev_shape_tensor is not None:
+            recv_prev_op = torch.distributed.P2POp(
+                torch.distributed.irecv, recv_prev_shape_tensor,
+                mpu.get_pipeline_model_parallel_prev_rank())
+            ops.append(recv_prev_op)
+        if send_next_shape_tensor is not None:
+            send_next_op = torch.distributed.P2POp(
+                torch.distributed.isend, send_next_shape_tensor,
+                mpu.get_pipeline_model_parallel_next_rank())
+            ops.append(send_next_op)
+        if recv_next_shape_tensor is not None:
+            recv_next_op = torch.distributed.P2POp(
+                torch.distributed.irecv, recv_next_shape_tensor,
+                mpu.get_pipeline_model_parallel_next_rank())
+            ops.append(recv_next_op)
+        if len(ops) > 0:
+            reqs = torch.distributed.batch_isend_irecv(ops)
+            for req in reqs:
+                req.wait()
+
+        # To protect against race condition when using batch_isend_irecv().
+        # Should take this out once the bug with batch_isend_irecv is resolved.
+        torch.cuda.synchronize()
+
+    recv_prev_shape = [0, 0, 0]
+    if recv_prev_shape_tensor is not None:
+        recv_prev_shape = recv_prev_shape_tensor.tolist()
+
+    recv_next_shape = [0, 0, 0]
+    if recv_next_shape_tensor is not None:
+        recv_next_shape = recv_next_shape_tensor.tolist()
+
+    return recv_prev_shape, recv_next_shape
+
+
 def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                  tensor_shape,
                  dtype_=None):
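The new _communicate_shapes above exchanges a small 3-element int64 tensor holding the activation's shape before the activation itself is sent. The script below is a self-contained illustration of that shape-before-payload pattern using the same torch.distributed primitives; it is not Megatron-LM code, and it assumes a two-process launch (e.g. torchrun --nproc_per_node=2 sketch.py) with the gloo backend so no GPUs are needed.

# Standalone sketch of the shape-before-payload pattern: rank 0 first sends
# the payload's shape as a small int64 tensor, rank 1 then allocates a buffer
# of exactly that size and receives the payload. Illustration only.
import torch
import torch.distributed as dist

def main():
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()
    assert dist.get_world_size() == 2, "sketch assumes exactly two ranks"
    peer = 1 - rank

    if rank == 0:
        # Pretend this activation's sequence length varies per micro-batch.
        payload = torch.randn(1536, 4, 8)   # (seq_len, micro_batch, hidden)
        shape = torch.tensor(payload.size(), dtype=torch.int64)
        ops = [dist.P2POp(dist.isend, shape, peer)]
    else:
        shape = torch.empty(3, dtype=torch.int64)
        ops = [dist.P2POp(dist.irecv, shape, peer)]
    for req in dist.batch_isend_irecv(ops):
        req.wait()

    # Second round: the receiver now knows how large a buffer to allocate.
    if rank == 0:
        ops = [dist.P2POp(dist.isend, payload, peer)]
    else:
        payload = torch.empty(tuple(shape.tolist()))
        ops = [dist.P2POp(dist.irecv, payload, peer)]
    for req in dist.batch_isend_irecv(ops):
        req.wait()

    if rank == 1:
        print("received tensor of shape", tuple(payload.shape))
    dist.destroy_process_group()

if __name__ == "__main__":
    main()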
@@ -41,21 +131,39 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     # Some legacy inference code doesn't set the tensor shape, do so now
     # for the normal values for gpt/bert. This could be removed if inference
     # code is changed to provide tensor_shape.
-    if tensor_shape is None:
-        tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    if not args.variable_seq_lengths:
+        if tensor_shape is None:
+            recv_prev_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+            recv_next_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+        else:
+            recv_prev_shape = tensor_shape
+            recv_next_shape = tensor_shape
+    else:
+        recv_prev_shape, recv_next_shape = \
+            _communicate_shapes(tensor_send_next,
+                                tensor_send_prev,
+                                recv_prev,
+                                recv_next)
 
     override_scatter_gather_tensors_in_pipeline = False
     if args.scatter_gather_tensors_in_pipeline and \
             not args.sequence_parallel:
-        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
-        if tensor_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
-            tensor_chunk_shape = tensor_chunk_shape // \
-                mpu.get_tensor_model_parallel_world_size()
+        recv_prev_chunk_shape = reduce(operator.mul, recv_prev_shape, 1)
+        recv_next_chunk_shape = reduce(operator.mul, recv_next_shape, 1)
+        if recv_prev_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0 and \
+                recv_next_chunk_shape % mpu.get_tensor_model_parallel_world_size() == 0:
+            recv_prev_chunk_shape = recv_prev_chunk_shape // \
+                mpu.get_tensor_model_parallel_world_size()
+            recv_next_chunk_shape = recv_next_chunk_shape // \
+                mpu.get_tensor_model_parallel_world_size()
         else:
-            tensor_chunk_shape = tensor_shape
+            recv_prev_chunk_shape = recv_prev_shape
+            recv_next_chunk_shape = recv_next_shape
             override_scatter_gather_tensors_in_pipeline = True
     else:
-        tensor_chunk_shape = tensor_shape
+        recv_prev_chunk_shape = recv_prev_shape
+        recv_next_chunk_shape = recv_next_shape
 
     dtype = args.params_dtype
     if args.fp32_residual_connection:
         dtype = torch.float
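With two receive shapes in play, the scatter/gather-in-pipeline optimization now checks divisibility for both directions before splitting tensors across tensor-model-parallel ranks. Below is a small standalone arithmetic sketch of that decision, with made-up shapes; it is an illustration, not Megatron-LM code.

# The optimization flattens each tensor and communicates only 1/tp_world_size
# of the elements per rank, so it is applied only when BOTH element counts
# divide evenly; otherwise this call falls back to full-shape communication.
from functools import reduce
import operator

def chunk_shapes(recv_prev_shape, recv_next_shape, tp_world_size):
    prev_numel = reduce(operator.mul, recv_prev_shape, 1)
    next_numel = reduce(operator.mul, recv_next_shape, 1)
    if prev_numel % tp_world_size == 0 and next_numel % tp_world_size == 0:
        return prev_numel // tp_world_size, next_numel // tp_world_size, False
    # Equivalent of the override_scatter_gather_tensors_in_pipeline fallback.
    return recv_prev_shape, recv_next_shape, True

print(chunk_shapes((6, 2, 5), (6, 2, 5), 4))   # (15, 15, False)
print(chunk_shapes((7, 1, 5), (6, 2, 5), 4))   # ((7, 1, 5), (6, 2, 5), True)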
@@ -66,12 +174,12 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         requires_grad = False
 
     if recv_prev:
-        tensor_recv_prev = torch.empty(tensor_chunk_shape,
+        tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
                                        requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
     if recv_next:
-        tensor_recv_next = torch.empty(tensor_chunk_shape,
+        tensor_recv_next = torch.empty(recv_next_chunk_shape,
                                        requires_grad=requires_grad,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
@@ -128,17 +236,17 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
             not args.sequence_parallel:
         if recv_prev:
             tensor_recv_prev = core.tensor_parallel.gather_split_1d_tensor(
-                tensor_recv_prev).view(tensor_shape).requires_grad_()
+                tensor_recv_prev).view(recv_prev_shape).requires_grad_()
             tensor_recv_prev = core.utils.make_viewless_tensor(tensor_recv_prev,
-                                                               requires_grad = True,
-                                                               keep_graph = False)
+                                                               requires_grad=True,
+                                                               keep_graph=False)
         if recv_next:
             tensor_recv_next = core.tensor_parallel.gather_split_1d_tensor(
-                tensor_recv_next).view(tensor_shape).requires_grad_()
+                tensor_recv_next).view(recv_next_shape).requires_grad_()
             tensor_recv_next = core.utils.make_viewless_tensor(tensor_recv_next,
-                                                               requires_grad = True,
-                                                               keep_graph = False)
+                                                               requires_grad=True,
+                                                               keep_graph=False)
 
     return tensor_recv_prev, tensor_recv_next