Commit 27fc4689 authored by Deepak Narayanan

Break up tensors sent between pipeline stages into smaller chunks that can be all-gathered

parent 8e922d5b
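
The new --scatter-gather-tensors-in-pipeline flag presumably exploits the fact that the activations sent between pipeline stages are replicated across tensor-model-parallel ranks, so each rank only needs to send its 1/t slice and the receiving stage can all-gather the full tensor back. A rough illustration of the payload reduction per message (the sizes below are made-up example values, not taken from this commit):

    # Hypothetical example sizes, not taken from this commit.
    seq_length = 1024
    micro_batch_size = 8
    hidden_size = 4096
    tensor_model_parallel_world_size = 8

    full_numel = seq_length * micro_batch_size * hidden_size        # 33,554,432 elements sent without the flag
    chunk_numel = full_numel // tensor_model_parallel_world_size    # 4,194,304 elements sent with the flag
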
@@ -566,6 +566,8 @@ def _add_distributed_args(parser):
                        choices=['local', 'torch'],
                        help='which DistributedDataParallel implementation '
                        'to use.')
+    group.add_argument('--scatter-gather-tensors-in-pipeline', action='store_true',
+                       help='Use scatter/gather to optimize communication of tensors in pipeline')
     group.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher.')
     group.add_argument('--lazy-mpu-init', type=bool, required=False,
@@ -59,6 +59,8 @@ from .random import get_cuda_rng_tracker
 from .random import init_checkpointed_activations_memory_buffer
 from .random import model_parallel_cuda_manual_seed
 from .random import reset_checkpointed_activations_memory_buffer
+from .random import gather_split_1d_tensor
+from .random import split_tensor_into_1d_equal_chunks
 from .utils import divide
 from .utils import split_tensor_along_last_dim
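
The two helpers imported above, split_tensor_into_1d_equal_chunks and gather_split_1d_tensor, live in mpu/random.py, which is not shown in this diff. A minimal sketch of what they could look like, assuming the usual mpu parallel-state accessors get_tensor_model_parallel_rank(), get_tensor_model_parallel_world_size(), and get_tensor_model_parallel_group():

    import torch
    from megatron import mpu  # assumed import path for the parallel-state accessors used below

    def split_tensor_into_1d_equal_chunks(tensor):
        # Flatten the tensor and return only this rank's contiguous slice,
        # i.e. 1/world_size of the elements.
        data = tensor.view(-1)
        partition_size = data.numel() // mpu.get_tensor_model_parallel_world_size()
        start = partition_size * mpu.get_tensor_model_parallel_rank()
        return data[start:start + partition_size]

    def gather_split_1d_tensor(tensor):
        # All-gather the per-rank chunks across the tensor-model-parallel group
        # back into a single flat tensor.
        world_size = mpu.get_tensor_model_parallel_world_size()
        numel = tensor.numel()
        gathered = torch.empty(world_size * numel, dtype=tensor.dtype,
                               device=torch.cuda.current_device(),
                               requires_grad=False)
        chunks = [gathered[i * numel:(i + 1) * numel] for i in range(world_size)]
        torch.distributed.all_gather(chunks, tensor.contiguous(),
                                     group=mpu.get_tensor_model_parallel_group())
        return gathered

Because every rank contributes an equal-sized chunk and the gather runs over the tensor-model-parallel group, each rank ends up with the full flattened tensor again before it is reshaped with .view(tensor_shape) in _communicate below.
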
@@ -29,20 +29,33 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
     tensor_recv_prev = None
     tensor_recv_next = None
     tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
+    if args.scatter_gather_tensors_in_pipeline:
+        tensor_chunk_shape = (
+            args.seq_length * args.micro_batch_size * args.hidden_size) // \
+            mpu.get_tensor_model_parallel_world_size()
+    else:
+        tensor_chunk_shape = tensor_shape
     dtype = args.params_dtype
     if args.fp32_residual_connection:
         dtype = torch.float
     if recv_prev:
-        tensor_recv_prev = torch.empty(tensor_shape,
+        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                        requires_grad=True,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
     if recv_next:
-        tensor_recv_next = torch.empty(tensor_shape,
+        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                        requires_grad=True,
                                        device=torch.cuda.current_device(),
                                        dtype=dtype)
+
+    if args.scatter_gather_tensors_in_pipeline:
+        if tensor_send_next is not None:
+            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
+        if tensor_send_prev is not None:
+            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)
+
     # Send tensors in both the forward and backward directions as appropriate.
     if use_ring_exchange:
         torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
@@ -71,6 +84,17 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
         reqs = torch.distributed.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()
+    torch.cuda.synchronize()
+
+    if args.scatter_gather_tensors_in_pipeline:
+        if recv_prev:
+            tensor_recv_prev = mpu.gather_split_1d_tensor(
+                tensor_recv_prev).view(tensor_shape).requires_grad_()
+        if recv_next:
+            tensor_recv_next = mpu.gather_split_1d_tensor(
+                tensor_recv_next).view(tensor_shape).requires_grad_()
+
     return tensor_recv_prev, tensor_recv_next
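
One constraint implied by the integer division used for tensor_chunk_shape above: seq_length * micro_batch_size * hidden_size must be divisible by the tensor-model-parallel world size, otherwise the chunks cannot exactly tile the original tensor and the later .view(tensor_shape) on the gathered result would fail. A hypothetical sanity check (not part of this commit) could sit where the chunk shape is computed:

    # Hypothetical sanity check; names mirror those used in _communicate.
    numel = args.seq_length * args.micro_batch_size * args.hidden_size
    world_size = mpu.get_tensor_model_parallel_world_size()
    assert numel % world_size == 0, (
        'scatter/gather optimization requires seq_length * micro_batch_size * '
        'hidden_size (%d) to be divisible by the tensor-model-parallel world '
        'size (%d)' % (numel, world_size))
    tensor_chunk_shape = numel // world_size
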
@@ -381,11 +381,11 @@ def train_step(forward_step_func, data_iterator,

     # Update parameters.
     timers('optimizer').start()
-    update_successfull, grad_norm = optimizer.step()
+    update_successful, grad_norm = optimizer.step()
     timers('optimizer').stop()

     # Update learning rate.
-    if update_successfull:
+    if update_successful:
         increment = get_num_microbatches() * \
                     args.micro_batch_size * \
                     args.data_parallel_size