"tests/vscode:/vscode.git/clone" did not exist on "e87a04e3c7c16aeb28ccd0cba693aa7c5ebaab48"
Commit 318d68c2 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Refactor communication code in main training loop to helper method

parent e805f0bd
@@ -59,6 +59,8 @@ def parse_args(extra_args_provider=None, defaults={},
     args.pipeline_model_parallel_size = min(
         args.pipeline_model_parallel_size,
         (args.world_size // args.tensor_model_parallel_size))
+    if args.num_microbatches_in_minibatch is None:
+        args.num_microbatches_in_minibatch = 1
     if args.rank == 0:
         print('using world size: {}, tensor-model-parallel size: {}, pipeline-model-parallel size: {} '.format(
             args.world_size, args.tensor_model_parallel_size, args.pipeline_model_parallel_size))
@@ -223,6 +225,8 @@ def _add_training_args(parser):
                        help='Batch size per model instance (local batch size). '
                        'Global batch size is local batch size times data '
                        'parallel size.')
+    group.add_argument('--num-microbatches-in-minibatch', type=int, default=None,
+                       help='Number of microbatches in minibatch')
     group.add_argument('--checkpoint-activations', action='store_true',
                        help='Checkpoint activation to allow for training '
                        'with larger models, sequences, and batch sizes.')
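
Taken together with the default handling added in parse_args() above, the new flag is opt-in: leaving it off resolves to a single microbatch per minibatch, which disables microbatch pipelining. The following minimal sketch is not part of this commit; it mirrors only the added lines and drops the rest of Megatron's argument plumbing:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--num-microbatches-in-minibatch', type=int, default=None,
                    help='Number of microbatches in minibatch')

# Flag omitted: the fallback added to parse_args() resolves None to 1.
args = parser.parse_args([])
if args.num_microbatches_in_minibatch is None:
    args.num_microbatches_in_minibatch = 1
assert args.num_microbatches_in_minibatch == 1

# Flag supplied: the requested number of microbatches is used as-is.
args = parser.parse_args(['--num-microbatches-in-minibatch', '4'])
assert args.num_microbatches_in_minibatch == 4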
@@ -368,8 +372,6 @@ def _add_distributed_args(parser):
                        help='Degree of tensor model parallelism.')
     group.add_argument('--pipeline-model-parallel-size', type=int, default=1,
                        help='Degree of pipeline model parallelism.')
-    group.add_argument('--use-pipelining', action='store_true',
-                       help='Use pipelining to increase throughput of pipeline model parallelism')
     group.add_argument('--distributed-backend', default='nccl',
                        choices=['nccl', 'gloo'],
                        help='Which backend to use for distributed training.')
......
@@ -138,7 +138,7 @@ def get_model(model_provider_func):
         model = FP16_Module(model)

     # Wrap model for distributed training."""
-    if args.use_pipelining:
+    if args.num_microbatches_in_minibatch > 1:
         assert args.DDP_impl == 'local'

     if args.DDP_impl == 'torch':
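
One reading of the tightened guard above (an inference, not something the diff states): once more than one microbatch is in flight, the gradient all-reduce has to be deferred until every microbatch's backward pass has finished, which Megatron's local DDP wrapper does and which this commit relies on in the 'All-reduce if needed' block at the end of train_step further down. Torch's DistributedDataParallel, by contrast, synchronizes gradients inside each backward call unless that is explicitly suppressed. A hypothetical sketch of how one would suppress it with torch's no_sync() context, shown only for contrast:

import contextlib

def run_microbatches(ddp_model, microbatches, loss_fn):
    # Accumulate gradients locally; let DDP all-reduce only on the last microbatch.
    for i, batch in enumerate(microbatches):
        is_last = (i == len(microbatches) - 1)
        ctx = contextlib.nullcontext() if is_last else ddp_model.no_sync()
        with ctx:
            loss = loss_fn(ddp_model(batch))
            loss.backward()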
@@ -291,6 +291,67 @@ def backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad):
     return input_tensor_grad


+def forward_step_with_communication(forward_step_func, data_iterator, model,
+                                    input_tensors, output_tensors,
+                                    losses_reduced, timers):
+    if not mpu.is_pipeline_first_stage():
+        input_tensor, _ = communicate(
+            tensor_send_next=None,
+            tensor_send_prev=None,
+            recv_forward=True,
+            recv_backward=False)
+    else:
+        input_tensor = None
+
+    # Forward model for one step.
+    timers('forward').start()
+    output_tensor = forward_step_func(data_iterator, model, input_tensor)
+    timers('forward').stop()
+
+    if mpu.is_pipeline_last_stage():
+        loss, loss_reduced = output_tensor
+        output_tensor = loss
+        losses_reduced.append(loss_reduced)
+    else:
+        communicate(
+            tensor_send_next=output_tensor,
+            tensor_send_prev=None,
+            recv_forward=False,
+            recv_backward=False)
+
+    input_tensors.append(input_tensor)
+    output_tensors.append(output_tensor)
+
+
+def backward_step_with_communication(optimizer, model, input_tensors, output_tensors, timers):
+    """Backward step."""
+    input_tensor = input_tensors.pop(0)
+    output_tensor = output_tensors.pop(0)
+
+    if mpu.is_pipeline_last_stage():
+        output_tensor_grad = None
+    else:
+        _, output_tensor_grad = communicate(
+            tensor_send_next=None,
+            tensor_send_prev=None,
+            recv_forward=False,
+            recv_backward=True)
+
+    # Backward pass for one step.
+    # TODO: This timer is a bit redundant now with backward-backward.
+    timers('backward').start()
+    input_grad_tensor = \
+        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
+    timers('backward').stop()
+
+    if not mpu.is_pipeline_first_stage():
+        communicate(
+            tensor_send_next=None,
+            tensor_send_prev=input_grad_tensor,
+            recv_forward=False,
+            recv_backward=False)
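
Both helpers funnel every point-to-point transfer through a single communicate() call whose implementation is not shown in this diff. The sketch below is only one reading of its contract as used above (optionally send to the next or previous pipeline stage, optionally receive the forward activation or the backward gradient, and return the pair of received tensors); the explicit ranks, fixed tensor shape, and blocking send/recv calls are placeholder assumptions, not the repository's implementation:

import torch
import torch.distributed as dist

def communicate_sketch(tensor_send_next, tensor_send_prev,
                       recv_forward, recv_backward,
                       prev_rank, next_rank,
                       tensor_shape, dtype=torch.float32, device='cuda'):
    """Hypothetical stand-in for communicate(); returns (tensor_recv_prev, tensor_recv_next)."""
    tensor_recv_prev = None   # forward activation expected from the previous stage
    tensor_recv_next = None   # gradient expected from the next stage
    if recv_forward:
        tensor_recv_prev = torch.empty(tensor_shape, dtype=dtype, device=device)
    if recv_backward:
        tensor_recv_next = torch.empty(tensor_shape, dtype=dtype, device=device)

    # Blocking calls keep the sketch short; a real helper would batch and
    # overlap these transfers and agree on shapes/dtypes across stages.
    if tensor_send_prev is not None:
        dist.send(tensor_send_prev, dst=prev_rank)
    if tensor_recv_prev is not None:
        dist.recv(tensor_recv_prev, src=prev_rank)
    if tensor_send_next is not None:
        dist.send(tensor_send_next, dst=next_rank)
    if tensor_recv_next is not None:
        dist.recv(tensor_recv_next, src=next_rank)
    return tensor_recv_prev, tensor_recv_next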
 def train_step(forward_step_func, data_iterator,
                model, optimizer, lr_scheduler):
     """Single training step."""
@@ -304,70 +365,41 @@ def train_step(forward_step_func, data_iterator,
     optimizer.zero_grad()

     # Compute number of microbatches in a minibatch.
-    num_microbatches_to_pipeline = args.pipeline_model_parallel_size \
-        if args.use_pipelining else 1
+    num_microbatches_in_minibatch = args.num_microbatches_in_minibatch
+    # TODO: Switch to the following schedule when async communication is supported
+    # so that we can facilitate more memory-efficient training.
+    # num_warmup_microbatches = \
+    #     (torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group()) -
+    #      torch.distributed.get_rank(group=mpu.get_pipeline_model_parallel_group()) - 1)
+    # num_warmup_microbatches = min(
+    #     num_warmup_microbatches,
+    #     num_microbatches_in_minibatch)
+    num_warmup_microbatches = num_microbatches_in_minibatch

     input_tensors = []
     output_tensors = []
     losses_reduced = []

-    # Run forward pass for all microbatches in minibatch.
-    for i in range(num_microbatches_to_pipeline):
-        if not mpu.is_pipeline_first_stage():
-            input_tensor, _ = communicate(
-                tensor_send_next=None,
-                tensor_send_prev=None,
-                recv_forward=True,
-                recv_backward=False)
-        else:
-            input_tensor = None
-
-        # Forward model for one step.
-        timers('forward').start()
-        output_tensor = forward_step_func(data_iterator, model, input_tensor)
-        timers('forward').stop()
-
-        if mpu.is_pipeline_last_stage():
-            loss, loss_reduced = output_tensor
-            output_tensor = loss
-            losses_reduced.append(loss_reduced)
-        else:
-            communicate(
-                tensor_send_next=output_tensor,
-                tensor_send_prev=None,
-                recv_forward=False,
-                recv_backward=False)
-
-        input_tensors.append(input_tensor)
-        output_tensors.append(output_tensor)
-
-    # Run backward pass for all microbatches in minibatch.
-    for i in range(num_microbatches_to_pipeline):
-        input_tensor = input_tensors.pop(0)
-        output_tensor = output_tensors.pop(0)
-
-        if mpu.is_pipeline_last_stage():
-            output_grad_tensor = None
-        else:
-            _, output_grad_tensor = communicate(
-                tensor_send_next=None,
-                tensor_send_prev=None,
-                recv_forward=False,
-                recv_backward=True)
-
-        # Backward pass for one step.
-        # TODO: This timer is a bit redundant now with backward-backward.
-        timers('backward').start()
-        input_grad_tensor = \
-            backward_step(optimizer, model, input_tensor, output_tensor, output_grad_tensor)
-        timers('backward').stop()
-
-        if not mpu.is_pipeline_first_stage():
-            communicate(
-                tensor_send_next=None,
-                tensor_send_prev=input_grad_tensor,
-                recv_forward=False,
-                recv_backward=False)
+    # Run warmup forward passes.
+    for i in range(num_warmup_microbatches):
+        forward_step_with_communication(
+            forward_step_func, data_iterator, model,
+            input_tensors, output_tensors,
+            losses_reduced, timers)
+
+    # Run 1F1B.
+    for i in range(num_microbatches_in_minibatch - num_warmup_microbatches):
+        forward_step_with_communication(
+            forward_step_func, data_iterator, model,
+            input_tensors, output_tensors,
+            losses_reduced, timers)
+
+        backward_step_with_communication(
+            optimizer, model, input_tensors, output_tensors, timers)
+
+    # Run cooldown backward passes.
+    for i in range(num_warmup_microbatches):
+        backward_step_with_communication(
+            optimizer, model, input_tensors, output_tensors, timers)

     # All-reduce if needed.
     if args.DDP_impl == 'local':
......
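
The refactored train_step now reads as three phases: warmup forward passes, a one-forward-one-backward (1F1B) steady state, and a cooldown that drains the remaining backward passes. Because num_warmup_microbatches is currently set to the whole minibatch, the middle loop is empty and the schedule degenerates to all forwards followed by all backwards; the commented-out formula would shrink the warmup to enable true 1F1B once async communication is supported. The trace below is a plain-Python illustration of that control flow and is not taken from the repository:

def schedule_trace(num_microbatches_in_minibatch, num_warmup_microbatches):
    """Order of forward ('F') and backward ('B') passes, per microbatch index."""
    trace = []
    # Warmup: forward passes only.
    for i in range(num_warmup_microbatches):
        trace.append(('F', i))
    # Steady state: one forward followed by one backward (1F1B).
    for i in range(num_microbatches_in_minibatch - num_warmup_microbatches):
        trace.append(('F', num_warmup_microbatches + i))
        trace.append(('B', i))
    # Cooldown: drain the remaining backward passes in order.
    done = num_microbatches_in_minibatch - num_warmup_microbatches
    for i in range(num_warmup_microbatches):
        trace.append(('B', done + i))
    return trace

# Current setting (warmup == whole minibatch): F0 F1 F2 F3 B0 B1 B2 B3
print(schedule_trace(4, 4))
# Hypothetical smaller warmup of 1: F0 F1 B0 F2 B1 F3 B2 B3
print(schedule_trace(4, 1))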