Commit 27fc4689 authored by Deepak Narayanan

Break up tensors sent between pipeline stages into smaller chunks that can be all-gathered

parent 8e922d5b
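The change exploits the fact that, with Megatron's tensor model parallelism, the activation handed from one pipeline stage to the next is replicated on every rank of a tensor-model-parallel group. Instead of each rank sending the full (seq_length, micro_batch_size, hidden_size) tensor to its peer on the next stage, each rank sends only its 1/t slice of the flattened tensor (t = tensor-model-parallel world size), and the receiving stage reconstructs the full tensor with an all-gather over its own tensor-parallel group. Below is a minimal conceptual sketch of that send/receive pattern, not the Megatron-LM code itself; the function names and the explicit `tp_group` / `tp_rank` arguments are placeholders for illustration.

```python
import torch
import torch.distributed as dist

def send_activation_scattered(activation, dst, tp_rank, tp_world_size):
    # Each tensor-parallel rank sends only its own 1/t slice of the
    # (replicated) activation to the matching rank on the next stage.
    flat = activation.contiguous().view(-1)
    chunk = flat.numel() // tp_world_size          # assumes even divisibility
    dist.send(flat[tp_rank * chunk:(tp_rank + 1) * chunk], dst)

def recv_activation_gathered(shape, dtype, src, tp_group, tp_world_size):
    # Receive this rank's slice, then all-gather across the tensor-parallel
    # group to rebuild the full activation on every rank.
    numel = 1
    for s in shape:
        numel *= s
    local_slice = torch.empty(numel // tp_world_size, dtype=dtype, device='cuda')
    dist.recv(local_slice, src)
    chunks = [torch.empty_like(local_slice) for _ in range(tp_world_size)]
    dist.all_gather(chunks, local_slice, group=tp_group)
    return torch.cat(chunks).view(shape)
```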
@@ -566,6 +566,8 @@ def _add_distributed_args(parser):
choices=['local', 'torch'],
help='which DistributedDataParallel implementation '
'to use.')
group.add_argument('--scatter-gather-tensors-in-pipeline', action='store_true',
help='Use scatter/gather to optimize communication of tensors in pipeline')
group.add_argument('--local_rank', type=int, default=None,
help='local rank passed from distributed launcher.')
group.add_argument('--lazy-mpu-init', type=bool, required=False,
......
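The optimization is opt-in via `--scatter-gather-tensors-in-pipeline`. Because the chunk size in `_communicate` is computed with integer division (see the next hunk), `seq_length * micro_batch_size * hidden_size` should divide evenly by the tensor-model-parallel world size. A hedged sketch of a sanity check one could add; `check_scatter_gather_divisibility` is hypothetical and not part of this commit.

```python
def check_scatter_gather_divisibility(args, tp_world_size):
    # Hypothetical helper (not in this commit): the flattened activation must
    # split evenly across the tensor-model-parallel group for the
    # scatter/gather path to round-trip without dropping elements.
    if not args.scatter_gather_tensors_in_pipeline:
        return
    numel = args.seq_length * args.micro_batch_size * args.hidden_size
    assert numel % tp_world_size == 0, (
        'seq_length * micro_batch_size * hidden_size must be divisible by '
        'the tensor-model-parallel world size when '
        '--scatter-gather-tensors-in-pipeline is set')
```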
@@ -59,6 +59,8 @@ from .random import get_cuda_rng_tracker
from .random import init_checkpointed_activations_memory_buffer
from .random import model_parallel_cuda_manual_seed
from .random import reset_checkpointed_activations_memory_buffer
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
from .utils import divide
from .utils import split_tensor_along_last_dim
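The two helpers re-exported here live in `mpu/random.py`. Roughly, they behave as sketched below; this is a paraphrase for illustration rather than the exact implementation, which may differ in details such as buffer reuse.

```python
import torch
from megatron import mpu

def split_tensor_into_1d_equal_chunks(tensor):
    # Flatten the tensor and return only this rank's contiguous 1/t slice.
    data = tensor.view(-1)
    partition = data.numel() // mpu.get_tensor_model_parallel_world_size()
    start = partition * mpu.get_tensor_model_parallel_rank()
    return data[start:start + partition]

def gather_split_1d_tensor(tensor):
    # Inverse of the above: all-gather every rank's slice over the
    # tensor-model-parallel group into one flat 1D tensor.
    world_size = mpu.get_tensor_model_parallel_world_size()
    numel = tensor.numel()
    gathered = torch.empty(world_size * numel, dtype=tensor.dtype,
                           device=torch.cuda.current_device())
    chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
    torch.distributed.all_gather(chunks, tensor,
                                 group=mpu.get_tensor_model_parallel_group())
    return gathered
```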
@@ -29,20 +29,33 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
tensor_recv_prev = None
tensor_recv_next = None
tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
if args.scatter_gather_tensors_in_pipeline:
tensor_chunk_shape = (
args.seq_length * args.micro_batch_size * args.hidden_size) // \
mpu.get_tensor_model_parallel_world_size()
else:
tensor_chunk_shape = tensor_shape
dtype = args.params_dtype
if args.fp32_residual_connection:
dtype = torch.float
if recv_prev:
tensor_recv_prev = torch.empty(tensor_shape,
tensor_recv_prev = torch.empty(tensor_chunk_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
if recv_next:
tensor_recv_next = torch.empty(tensor_shape,
tensor_recv_next = torch.empty(tensor_chunk_shape,
requires_grad=True,
device=torch.cuda.current_device(),
dtype=dtype)
if args.scatter_gather_tensors_in_pipeline:
if tensor_send_next is not None:
tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
if tensor_send_prev is not None:
tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate.
if use_ring_exchange:
torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
@@ -71,6 +84,17 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
torch.cuda.synchronize()
if args.scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = mpu.gather_split_1d_tensor(
tensor_recv_prev).view(tensor_shape).requires_grad_()
if recv_next:
tensor_recv_next = mpu.gather_split_1d_tensor(
tensor_recv_next).view(tensor_shape).requires_grad_()
return tensor_recv_prev, tensor_recv_next
......
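The payoff: the point-to-point payload between pipeline stages shrinks by the tensor-model-parallel factor, at the cost of one extra all-gather within each tensor-parallel group on the receiving side, which typically runs over faster intra-node links. A back-of-the-envelope illustration with example sizes that are not taken from this commit:

```python
# Example sizes for illustration only (not from this commit).
seq_length, micro_batch_size, hidden_size = 2048, 8, 3072
tp_world_size = 8
bytes_per_elem = 2                       # fp16 activations

full_bytes = seq_length * micro_batch_size * hidden_size * bytes_per_elem
chunk_bytes = full_bytes // tp_world_size

print(f'full activation per p2p message   : {full_bytes / 2**20:.0f} MiB')   # 96 MiB
print(f'per-rank chunk with scatter/gather: {chunk_bytes / 2**20:.0f} MiB')  # 12 MiB
```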
@@ -381,11 +381,11 @@ def train_step(forward_step_func, data_iterator,
# Update parameters.
timers('optimizer').start()
update_successfull, grad_norm = optimizer.step()
update_successful, grad_norm = optimizer.step()
timers('optimizer').stop()
# Update learning rate.
if update_successfull:
if update_successful:
increment = get_num_microbatches() * \
args.micro_batch_size * \
args.data_parallel_size
......
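Aside from the spelling fix (`update_successfull` → `update_successful`), the surrounding code computes `increment` as the number of samples consumed per optimizer step, which is what advances the learning-rate scheduler. With illustrative values (not from this commit):

```python
# Samples consumed per optimizer step = global batch size.
num_microbatches   = 4    # get_num_microbatches()
micro_batch_size   = 8    # args.micro_batch_size
data_parallel_size = 16   # args.data_parallel_size
increment = num_microbatches * micro_batch_size * data_parallel_size  # 512
```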