Commit 1979c242 authored by Deepak Narayanan's avatar Deepak Narayanan
Browse files

Refactor code according to Jared's comments: move pipelining and non-pipelining training loops into separate methods

Refactor code according to Jared's comments: move pipelining and non-pipelining training loops into separate methods

Also, use mpu.get_*_model_parallel_size() instead of args.*_model_parallel_size
parent 9ff6f473
......@@ -18,7 +18,7 @@ def general_ict_model_provider(only_query_model=False, only_block_model=False):
args = get_args()
assert args.ict_head_size is not None, \
"Need to specify --ict-head-size to provide an ICTBertModel"
assert args.tensor_model_parallel_size == 1 and args.pipeline_model_parallel_size == 1, \
assert mpu.get_tensor_model_parallel_world_size() == 1 and mpu.get_pipeline_model_parallel_world_size() == 1, \
"Model parallel size > 1 not supported for ICT"
print_rank_0('building ICTBertModel...')
......
......@@ -505,9 +505,9 @@ class ParallelTransformer(MegatronModule):
self.checkpoint_num_layers = args.checkpoint_num_layers
# Number of layers.
assert args.num_layers % args.pipeline_model_parallel_size == 0, \
assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
'num_layers must be divisible by pipeline_model_parallel_size'
self.num_layers = args.num_layers // args.pipeline_model_parallel_size
self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
# Transformer layers.
def build_layer(layer_number):
......
......@@ -409,19 +409,34 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterat
return input_tensor
def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler):
"""Single training step."""
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers):
"""Run forward and backward passes without inter-stage communication."""
args = get_args()
timers = get_timers()
# Set grad to zero.
if args.fp16:
optimizer.zero_grad(set_grads_to_None=True)
else:
optimizer.zero_grad()
losses_reduced = []
for i in range(args.num_microbatches_in_minibatch):
timers('forward-compute').start()
loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
output_tensor = loss
losses_reduced.append(loss_reduced)
timers('forward-compute').stop()
timers('backward-compute').start()
output_tensor_grad = None
backward_step(optimizer, model, input_tensor=None,
output_tensor=output_tensor, output_tensor_grad=None)
timers('backward-compute').stop()
return losses_reduced
# Compute number of microbatches in a minibatch.
def forward_backward_pipelining(forward_step_func, data_iterator, model,
optimizer, timers):
"""Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
args = get_args()
# Compute number of warmup microbatches.
num_microbatches_in_minibatch = args.num_microbatches_in_minibatch
num_warmup_microbatches = \
(mpu.get_pipeline_model_parallel_world_size() -
......@@ -429,6 +444,8 @@ def train_step(forward_step_func, data_iterator,
num_warmup_microbatches = min(
num_warmup_microbatches,
num_microbatches_in_minibatch)
num_microbatches_in_minibatch_remaining = \
num_microbatches_in_minibatch - num_warmup_microbatches
input_tensors = []
output_tensors = []
......@@ -436,23 +453,15 @@ def train_step(forward_step_func, data_iterator,
# Run warmup forward passes.
for i in range(num_warmup_microbatches):
if args.pipeline_model_parallel_size > 1:
forward_step_with_communication(
forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers)
else:
timers('forward-compute').start()
input_tensor = None
loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor)
output_tensor = loss
losses_reduced.append(loss_reduced)
input_tensors.append(input_tensor)
output_tensors.append(output_tensor)
timers('forward-compute').stop()
forward_step_with_communication(
forward_step_func, data_iterator, model,
input_tensors, output_tensors,
losses_reduced, timers)
# Before running 1F1B, need to receive first forward tensor.
if (num_microbatches_in_minibatch - num_warmup_microbatches) > 0:
# If all microbatches are run in warmup / cooldown phase, then no need to
# receive this tensor here.
if num_microbatches_in_minibatch_remaining > 0:
if mpu.is_pipeline_first_stage():
input_tensor = None
else:
......@@ -464,8 +473,8 @@ def train_step(forward_step_func, data_iterator,
timers('forward-recv').stop()
# Run 1F1B.
for i in range(num_microbatches_in_minibatch - num_warmup_microbatches):
last_iteration = (i == (num_microbatches_in_minibatch - num_warmup_microbatches - 1))
for i in range(num_microbatches_in_minibatch_remaining):
last_iteration = (i == (num_microbatches_in_minibatch_remaining - 1))
input_tensor = \
forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
optimizer,
......@@ -475,16 +484,30 @@ def train_step(forward_step_func, data_iterator,
# Run cooldown backward passes.
for i in range(num_warmup_microbatches):
if args.pipeline_model_parallel_size > 1:
backward_step_with_communication(
optimizer, model, input_tensors, output_tensors, timers)
else:
timers('backward-compute').start()
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
output_tensor_grad = None
backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
timers('backward-compute').stop()
backward_step_with_communication(
optimizer, model, input_tensors, output_tensors, timers)
return losses_reduced
def train_step(forward_step_func, data_iterator,
model, optimizer, lr_scheduler):
"""Single training step."""
args = get_args()
timers = get_timers()
# Set grad to zero.
if args.fp16:
optimizer.zero_grad(set_grads_to_None=True)
else:
optimizer.zero_grad()
if mpu.get_pipeline_model_parallel_world_size() > 1:
losses_reduced = forward_backward_pipelining(
forward_step_func, data_iterator, model, optimizer, timers)
else:
losses_reduced = forward_backward_no_pipelining(
forward_step_func, data_iterator, model, optimizer, timers)
# All-reduce if needed.
if args.DDP_impl == 'local':
......@@ -499,7 +522,7 @@ def train_step(forward_step_func, data_iterator,
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
args.pipeline_model_parallel_size > 1:
mpu.get_pipeline_model_parallel_world_size() > 1:
unwrapped_model = model
while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
unwrapped_model = unwrapped_model.module
......
......@@ -34,7 +34,7 @@ def model_provider():
print_rank_0('building BERT model ...')
args = get_args()
if args.pipeline_model_parallel_size > 1:
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = BertModelFirstStage(
......
......@@ -33,7 +33,7 @@ def model_provider():
print_rank_0('building GPT2 model ...')
args = get_args()
if args.pipeline_model_parallel_size > 1:
if mpu.get_pipeline_model_parallel_world_size() > 1:
# Determine model based on position of stage in pipeline.
if mpu.is_pipeline_first_stage():
model = GPT2ModelFirstStage(num_tokentypes=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment