"vscode:/vscode.git/clone" did not exist on "541ce716b9dacae94ef8cd99f179df16b9f73d84"
Commit 3e6898e6 authored by Deepak Narayanan

Move training schedule to 1F1B for memory efficiency

parent 6abf39be
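For context (an illustrative note, not part of the commit): with the previous all-forward-then-all-backward schedule, every pipeline stage stashes activations for the whole minibatch before the first backward pass frees any of them. Under 1F1B, a stage holds at most pipeline_depth - rank microbatches' activations at a time. A minimal sketch of the peak counts, assuming the warmup formula this commit introduces (hypothetical helper, illustrative numbers):

def peak_stashed_microbatches(schedule, num_stages, rank, num_microbatches):
    # Peak number of microbatches whose activations a stage must hold.
    if schedule == 'all_forward_all_backward':
        return num_microbatches          # everything is in flight at once
    warmup = min(num_stages - rank - 1, num_microbatches)
    return min(warmup + 1, num_microbatches)  # 1F1B: warmup + one in steady state

for rank in range(4):
    print(rank,
          peak_stashed_microbatches('all_forward_all_backward', 4, rank, 16),
          peak_stashed_microbatches('1f1b', 4, rank, 16))
# rank 0 holds 16 vs. 4 microbatches' activations; rank 3 holds 16 vs. 1.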
@@ -357,6 +357,60 @@ def backward_step_with_communication(optimizer, model, input_tensors, output_ten
    timers('backward-send').stop()


def forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                  optimizer,
                                                  input_tensor, last_microbatch,
                                                  input_tensors, output_tensors,
                                                  losses_reduced, timers):
    # Forward model for one step.
    timers('forward-compute').start()
    output_tensor = forward_step_func(data_iterator, model, input_tensor)
    timers('forward-compute').stop()

    if mpu.is_pipeline_last_stage():
        loss, loss_reduced = output_tensor
        output_tensor = loss
        output_tensor_grad = None
        losses_reduced.append(loss_reduced)
    else:
        timers('forward-send').start()
        timers('backward-recv').start()
        _, output_tensor_grad = communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_forward=False,
            recv_backward=True)
        timers('forward-send').stop()
        timers('backward-recv').stop()

    input_tensors.append(input_tensor)
    output_tensors.append(output_tensor)
    input_tensor = input_tensors.pop(0)
    output_tensor = output_tensors.pop(0)

    # Backward pass for one step.
    timers('backward-compute').start()
    input_grad_tensor = \
        backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
    timers('backward-compute').stop()

    if not mpu.is_pipeline_first_stage():
        timers('backward-send').start()
        timers('forward-recv').start()
        input_tensor, _ = communicate(
            tensor_send_next=None,
            tensor_send_prev=input_grad_tensor,
            recv_forward=(not last_microbatch),
            recv_backward=False)
        timers('backward-send').stop()
        timers('forward-recv').stop()
    else:
        input_tensor = None

    return input_tensor
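One detail worth noting in the function above: the append followed by pop(0) makes input_tensors/output_tensors behave as FIFO queues, so each backward pass applies to the oldest microbatch still in flight, and the number of stashed activations stays constant through the steady state. A minimal sketch of that behavior (plain strings standing in for tensors; not this repo's code):

input_tensors = ['fwd0', 'fwd1', 'fwd2']    # left over from 3 warmup forwards
for step in range(3, 6):                     # three steady-state 1F1B steps
    input_tensors.append(f'fwd{step}')       # forward pass for a new microbatch
    oldest = input_tensors.pop(0)            # backward pass for the oldest one
    print(f'step {step}: backward on {oldest}, in flight: {input_tensors}')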
def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
@@ -371,18 +425,12 @@ def train_step(forward_step_func, data_iterator,
    # Compute number of microbatches in a minibatch.
    num_microbatches_in_minibatch = args.num_microbatches_in_minibatch

-    # For now, perform training without warmup. Perform forward
-    # passes for all microbatches, then backward passes for all
-    # microbatches.
-    # TODO: Switch to the following schedule to facilitate more
-    # memory-efficient training.
-    # num_warmup_microbatches = \
-    #     (torch.distributed.get_world_size(group=mpu.get_pipeline_model_parallel_group()) -
-    #      torch.distributed.get_rank(group=mpu.get_pipeline_model_parallel_group()) - 1)
-    # num_warmup_microbatches = min(
-    #     num_warmup_microbatches,
-    #     num_microbatches_in_minibatch)
-    num_warmup_microbatches = num_microbatches_in_minibatch
+    num_warmup_microbatches = \
+        (mpu.get_pipeline_model_parallel_world_size() -
+         mpu.get_pipeline_model_parallel_rank() - 1)
+    num_warmup_microbatches = min(
+        num_warmup_microbatches,
+        num_microbatches_in_minibatch)
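A quick sanity check of the new warmup computation (illustrative numbers, not part of the diff): warmup counts fall off linearly with pipeline rank and are capped at the minibatch size.

# Hypothetical standalone check of the warmup formula introduced above.
pipeline_world_size = 4
num_microbatches_in_minibatch = 16
for rank in range(pipeline_world_size):
    num_warmup = min(pipeline_world_size - rank - 1,
                     num_microbatches_in_minibatch)
    print(f'stage {rank}: {num_warmup} warmup microbatches')
# stage 0: 3, stage 1: 2, stage 2: 1, stage 3: 0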
    input_tensors = []
    output_tensors = []
@@ -407,6 +455,26 @@ def train_step(forward_step_func, data_iterator,
        timers('forward-compute').stop()
    timers('forward').stop()
    # Before running 1F1B, need to receive first forward tensor.
    if (num_microbatches_in_minibatch - num_warmup_microbatches) > 0:
        if mpu.is_pipeline_first_stage():
            input_tensor = None
        else:
            input_tensor, _ = communicate(tensor_send_next=None,
                                          tensor_send_prev=None,
                                          recv_forward=True,
                                          recv_backward=False)

    # Run 1F1B.
    for i in range(num_microbatches_in_minibatch - num_warmup_microbatches):
        last_iteration = (i == (num_microbatches_in_minibatch - num_warmup_microbatches - 1))

        input_tensor = \
            forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                          optimizer,
                                                          input_tensor, last_iteration,
                                                          input_tensors, output_tensors,
                                                          losses_reduced, timers)
    # Run cooldown backward passes.
    timers('backward').start()
    for i in range(num_warmup_microbatches):
...
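Taken together, the warmup forward loop, the 1F1B loop, and the cooldown backward loop give each stage the characteristic 1F1B profile. A sketch that prints each stage's event sequence (a hypothetical trace helper reusing this commit's warmup formula; no real computation or communication):

def schedule_events(num_stages, rank, num_microbatches):
    # Mirrors the three phases of train_step: warmup forwards,
    # steady-state forward/backward pairs, cooldown backwards.
    warmup = min(num_stages - rank - 1, num_microbatches)
    remaining = num_microbatches - warmup
    return ['F'] * warmup + ['F', 'B'] * remaining + ['B'] * warmup

for rank in range(4):
    print(rank, ''.join(schedule_events(4, rank, 8)))
# rank 0 prints FFFFBFBFBFBFBBBB: every stage runs 8 forwards and 8 backwards,
# but earlier stages interleave them only after a longer warmup.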