Commit 318d68c2 authored by Deepak Narayanan

Refactor communication code in main training loop to helper method

parent e805f0bd
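In short: the per-microbatch communicate-then-forward and receive-grad-then-backward logic that used to live inline in train_step moves into two helpers, forward_step_with_communication and backward_step_with_communication, and train_step itself is reduced to driving a warmup / 1F1B / cooldown schedule over num_microbatches_in_minibatch (with only synchronous communication for now, the warmup spans the whole minibatch, so the schedule still degenerates to all forwards followed by all backwards). Below is a minimal, self-contained sketch of that control flow: the two helper names mirror the ones added in this commit, but the dummy step bodies, the step_id bookkeeping, and the prints are illustrative stand-ins rather than the actual Megatron code, which passes the real model, optimizer, data iterator, and timers.

# A compact simulation of the schedule train_step drives after this refactor.
# The helper names mirror the ones added below; the bodies are dummies.

def forward_step_with_communication(step_id, input_tensors, output_tensors):
    # Real version: recv activation from the previous stage, run forward_step_func,
    # send the output to the next stage, and stash tensors for the matching backward.
    input_tensors.append(f"input[{step_id}]")
    output_tensors.append(f"output[{step_id}]")
    print(f"forward  microbatch {step_id}")


def backward_step_with_communication(input_tensors, output_tensors):
    # Real version: recv the output grad from the next stage, run backward_step on
    # the oldest outstanding microbatch, send the input grad to the previous stage.
    input_tensors.pop(0)
    output_tensors.pop(0)
    print("backward oldest outstanding microbatch")


def train_step(num_microbatches_in_minibatch):
    # With only synchronous communication available, warmup spans the whole
    # minibatch, so the 1F1B loop below runs zero iterations for now.
    num_warmup_microbatches = num_microbatches_in_minibatch

    input_tensors, output_tensors = [], []

    # Warmup forward passes.
    for i in range(num_warmup_microbatches):
        forward_step_with_communication(i, input_tensors, output_tensors)

    # 1F1B steady state: one forward immediately followed by one backward.
    for i in range(num_microbatches_in_minibatch - num_warmup_microbatches):
        forward_step_with_communication(num_warmup_microbatches + i,
                                        input_tensors, output_tensors)
        backward_step_with_communication(input_tensors, output_tensors)

    # Cooldown backward passes for the microbatches still in flight.
    for i in range(num_warmup_microbatches):
        backward_step_with_communication(input_tensors, output_tensors)


train_step(num_microbatches_in_minibatch=4)

Running the sketch with num_microbatches_in_minibatch=4 prints four forward lines followed by four backward lines, which is exactly the ordering the current synchronous schedule produces.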
@@ -59,6 +59,8 @@ def parse_args(extra_args_provider=None, defaults={},
     args.pipeline_model_parallel_size = min(
         args.pipeline_model_parallel_size,
         (args.world_size // args.tensor_model_parallel_size))
+    if args.num_microbatches_in_minibatch is None:
+        args.num_microbatches_in_minibatch = 1
     if args.rank == 0:
         print('using world size: {}, tensor-model-parallel size: {}, pipeline-model-parallel size: {} '.format(
             args.world_size, args.tensor_model_parallel_size, args.pipeline_model_parallel_size))
@@ -223,6 +225,8 @@ def _add_training_args(parser):
                        help='Batch size per model instance (local batch size). '
                        'Global batch size is local batch size times data '
                        'parallel size.')
+    group.add_argument('--num-microbatches-in-minibatch', type=int, default=None,
+                       help='Number of microbatches in minibatch')
     group.add_argument('--checkpoint-activations', action='store_true',
                        help='Checkpoint activation to allow for training '
                        'with larger models, sequences, and batch sizes.')
@@ -368,8 +372,6 @@ def _add_distributed_args(parser):
                        help='Degree of tensor model parallelism.')
     group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                        help='Degree of pipeline model parallelism.')
-    group.add_argument('--use-pipelining', action='store_true',
-                       help='Use pipelining to increase throughput of pipeline model parallelism')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
...
@@ -138,7 +138,7 @@ def get_model(model_provider_func):
         model = FP16_Module(model)

     # Wrap model for distributed training."""
-    if args.use_pipelining:
+    if args.num_microbatches_in_minibatch > 1:
         assert args.DDP_impl == 'local'

     if args.DDP_impl == 'torch':
@@ -291,6 +291,67 @@ def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_g
     return input_tensor_grad


+def forward_step_with_communication(forward_step_func, data_iterator, model,
+                                    input_tensors, output_tensors,
+                                    losses_reduced, timers):
+    if not mpu.is_pipeline_first_stage():
+        input_tensor, _ = communicate(
+            tensor_send_next=None,
+            tensor_send_prev=None,
+            recv_forward=True,
+            recv_backward=False)
+    else:
+        input_tensor = None
+
+    # Forward model for one step.
+    timers('forward').start()
+    output_tensor = forward_step_func(data_iterator, model, input_tensor)
+    timers('forward').stop()
+
+    if mpu.is_pipeline_last_stage():
+        loss, loss_reduced = output_tensor
+        output_tensor = loss
+        losses_reduced.append(loss_reduced)
+    else:
+        communicate(
+            tensor_send_next=output_tensor,
+            tensor_send_prev=None,
+            recv_forward=False,
+            recv_backward=False)
+
+    input_tensors.append(input_tensor)
+    output_tensors.append(output_tensor)
+
+
+def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
+    """Backward step."""
+    input_tensor = input_tensors.pop(0)
+    output_tensor = output_tensors.pop(0)
+
+    if mpu.is_pipeline_last_stage():
+        output_tensor_grad = None
+    else:
+        _, output_tensor_grad = communicate(
+            tensor_send_next=None,
+            tensor_send_prev=None,
+            recv_forward=False,
+            recv_backward=True)
+
+    # Backward pass for one step.
+    # TODO: This timer is a bit redundant now with backward-backward.
+    timers('backward').start()
+    input_grad_tensor = \
+        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
+    timers('backward').stop()
+
+    if not mpu.is_pipeline_first_stage():
+        communicate(
+            tensor_send_next=None,
+            tensor_send_prev=input_grad_tensor,
+            recv_forward=False,
+            recv_backward=False)
+
+
 def train_step(forward_step_func, data_iterator,
                model, optimizer, lr_scheduler):
     """Single training step."""
@@ -304,70 +365,41 @@ def train_step(forward_step_func, data_iterator,
     optimizer.zero_grad()

     # Compute number of microbatches in a minibatch.
-    num_microbatches_to_pipeline = args.pipeline_model_parallel_size \
-        if args.use_pipelining else 1
+    num_microbatches_in_minibatch = args.num_microbatches_in_minibatch
+    # TODO: Switch to the following schedule when async communication is supported
+    # so that we can facilitate more memory-efficient training.
+    # num_warmup_microbatches = \
+    #     (torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group()) -
+    #      torch.distributed.get_rank(group=mpu.get_pipeline_model_parallel_group()) - 1)
+    # num_warmup_microbatches = min(
+    #     num_warmup_microbatches,
+    #     num_microbatches_in_minibatch)
+    num_warmup_microbatches = num_microbatches_in_minibatch

     input_tensors = []
     output_tensors = []
     losses_reduced = []

-    # Run forward pass for all microbatches in minibatch.
-    for i in range(num_microbatches_to_pipeline):
-        if not mpu.is_pipeline_first_stage():
-            input_tensor, _ = communicate(
-                tensor_send_next=None,
-                tensor_send_prev=None,
-                recv_forward=True,
-                recv_backward=False)
-        else:
-            input_tensor = None
-
-        # Forward model for one step.
-        timers('forward').start()
-        output_tensor = forward_step_func(data_iterator, model, input_tensor)
-        timers('forward').stop()
-
-        if mpu.is_pipeline_last_stage():
-            loss, loss_reduced = output_tensor
-            output_tensor = loss
-            losses_reduced.append(loss_reduced)
-        else:
-            communicate(
-                tensor_send_next=output_tensor,
-                tensor_send_prev=None,
-                recv_forward=False,
-                recv_backward=False)
-
-        input_tensors.append(input_tensor)
-        output_tensors.append(output_tensor)
-
-    # Run backward pass for all microbatches in minibatch.
-    for i in range(num_microbatches_to_pipeline):
-        input_tensor = input_tensors.pop(0)
-        output_tensor = output_tensors.pop(0)
-
-        if mpu.is_pipeline_last_stage():
-            output_grad_tensor = None
-        else:
-            _, output_grad_tensor = communicate(
-                tensor_send_next=None,
-                tensor_send_prev=None,
-                recv_forward=False,
-                recv_backward=True)
-
-        # Backward pass for one step.
-        # TODO: This timer is a bit redundant now with backward-backward.
-        timers('backward').start()
-        input_grad_tensor = \
-            backward_step(optimizer, model, input_tensor, output_tensor, output_grad_tensor)
-        timers('backward').stop()
-
-        if not mpu.is_pipeline_first_stage():
-            communicate(
-                tensor_send_next=None,
-                tensor_send_prev=input_grad_tensor,
-                recv_forward=False,
-                recv_backward=False)
+    # Run warmup forward passes.
+    for i in range(num_warmup_microbatches):
+        forward_step_with_communication(
+            forward_step_func, data_iterator, model,
+            input_tensors, output_tensors,
+            losses_reduced, timers)
+
+    # Run 1F1B.
+    for i in range(num_microbatches_in_minibatch - num_warmup_microbatches):
+        forward_step_with_communication(
+            forward_step_func, data_iterator, model,
+            input_tensors, output_tensors,
+            losses_reduced, timers)
+        backward_step_with_communication(
+            optimizer, model, input_tensors, output_tensors, timers)
+
+    # Run cooldown backward passes.
+    for i in range(num_warmup_microbatches):
+        backward_step_with_communication(
+            optimizer, model, input_tensors, output_tensors, timers)

     # All-reduce if needed.
     if args.DDP_impl == 'local':
...
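The commented-out computation in train_step hints at where this schedule is headed once asynchronous communication is supported: each pipeline stage would only warm up with as many microbatches as there are stages after it (capped by the minibatch size), which bounds the number of outstanding activations per stage. A small sketch of that computation follows; pipeline_world_size and pipeline_rank are stand-ins for the torch.distributed world-size/rank calls in the comment, and the function name is illustrative only.

# Sketch of the commented-out per-stage warmup computation.

def num_warmup_microbatches(pipeline_world_size, pipeline_rank,
                            num_microbatches_in_minibatch):
    # Stages nearer the end of the pipeline need fewer in-flight microbatches
    # before steady 1F1B can start, which is what bounds activation memory.
    warmup = pipeline_world_size - pipeline_rank - 1
    return min(warmup, num_microbatches_in_minibatch)


# Example: 4 pipeline stages, 8 microbatches per minibatch.
for rank in range(4):
    print(rank, num_warmup_microbatches(4, rank, 8))
# prints: 0 3 / 1 2 / 2 1 / 3 0

With the code path actually committed here, num_warmup_microbatches is instead set to the full num_microbatches_in_minibatch, so every stage runs all of its forward passes before any backward pass.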