Commit 1979c242 authored by Deepak Narayanan

Refactor code according to Jared's comments: move pipelining and non-pipelining training loops into separate methods

Also, use mpu.get_*_model_parallel_world_size() instead of args.*_model_parallel_size
parent 9ff6f473
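
For orientation, here is a condensed sketch of the control flow of `train_step` after this refactor, assembled from the hunks below (grad-zeroing and all-reduce logic elided; see the diff for the real code):

```python
# Condensed sketch of the refactored train_step, assembled from the
# hunks in this commit; setup and all-reduce steps are elided.
def train_step(forward_step_func, data_iterator, model, optimizer, lr_scheduler):
    """Single training step: dispatch to one of two schedules."""
    if mpu.get_pipeline_model_parallel_world_size() > 1:
        # 1F1B pipeline schedule with inter-stage communication.
        losses_reduced = forward_backward_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)
    else:
        # Plain loop over microbatches, no inter-stage communication.
        losses_reduced = forward_backward_no_pipelining(
            forward_step_func, data_iterator, model, optimizer, timers)
```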
@@ -18,7 +18,7 @@ def general_ict_model_provider(only_query_model=False, only_block_model=False):
     args = get_args()
     assert args.ict_head_size is not None, \
         "Need to specify --ict-head-size to provide an ICTBertModel"
-    assert args.tensor_model_parallel_size == 1 and args.pipeline_model_parallel_size == 1, \
+    assert mpu.get_tensor_model_parallel_world_size() == 1 and mpu.get_pipeline_model_parallel_world_size() == 1, \
        "Model parallel size > 1 not supported for ICT"
 
     print_rank_0('building ICTBertModel...')
...
@@ -505,9 +505,9 @@ class ParallelTransformer(MegatronModule):
         self.checkpoint_num_layers = args.checkpoint_num_layers
 
         # Number of layers.
-        assert args.num_layers % args.pipeline_model_parallel_size == 0, \
+        assert args.num_layers % mpu.get_pipeline_model_parallel_world_size() == 0, \
             'num_layers must be divisible by pipeline_model_parallel_size'
-        self.num_layers = args.num_layers // args.pipeline_model_parallel_size
+        self.num_layers = args.num_layers // mpu.get_pipeline_model_parallel_world_size()
 
         # Transformer layers.
         def build_layer(layer_number):
...
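
A quick numeric check of the layer-partitioning assert above (hypothetical sizes, not from the diff):

```python
# Hypothetical sizes illustrating the divisibility assert above:
# each pipeline stage builds an equal slice of the transformer layers.
num_layers, pipeline_world_size = 24, 4
assert num_layers % pipeline_world_size == 0, \
    'num_layers must be divisible by pipeline_model_parallel_size'
layers_per_stage = num_layers // pipeline_world_size  # 6 layers on every stage
```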
@@ -409,19 +409,34 @@ def forward_and_backward_steps_with_communication(forward_step_func, data_iterator,
     return input_tensor
 
 
-def train_step(forward_step_func, data_iterator,
-               model, optimizer, lr_scheduler):
-    """Single training step."""
+def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
+                                   optimizer, timers):
+    """Run forward and backward passes without inter-stage communication."""
     args = get_args()
-    timers = get_timers()
 
-    # Set grad to zero.
-    if args.fp16:
-        optimizer.zero_grad(set_grads_to_None=True)
-    else:
-        optimizer.zero_grad()
+    losses_reduced = []
+    for i in range(args.num_microbatches_in_minibatch):
+        timers('forward-compute').start()
+        loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor=None)
+        output_tensor = loss
+        losses_reduced.append(loss_reduced)
+        timers('forward-compute').stop()
 
-    # Compute number of microbatches in a minibatch.
+        timers('backward-compute').start()
+        output_tensor_grad = None
+        backward_step(optimizer, model, input_tensor=None,
+                      output_tensor=output_tensor, output_tensor_grad=None)
+        timers('backward-compute').stop()
+
+    return losses_reduced
+
+
+def forward_backward_pipelining(forward_step_func, data_iterator, model,
+                                optimizer, timers):
+    """Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
+    args = get_args()
+
+    # Compute number of warmup microbatches.
     num_microbatches_in_minibatch = args.num_microbatches_in_minibatch
     num_warmup_microbatches = \
         (mpu.get_pipeline_model_parallel_world_size() -
@@ -429,6 +444,8 @@ def train_step(forward_step_func, data_iterator,
     num_warmup_microbatches = min(
         num_warmup_microbatches,
         num_microbatches_in_minibatch)
+    num_microbatches_in_minibatch_remaining = \
+        num_microbatches_in_minibatch - num_warmup_microbatches
 
     input_tensors = []
     output_tensors = []
@@ -436,23 +453,15 @@ def train_step(forward_step_func, data_iterator,
 
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
-        if args.pipeline_model_parallel_size > 1:
-            forward_step_with_communication(
-                forward_step_func, data_iterator, model,
-                input_tensors, output_tensors,
-                losses_reduced, timers)
-        else:
-            timers('forward-compute').start()
-            input_tensor = None
-            loss, loss_reduced = forward_step_func(data_iterator, model, input_tensor)
-            output_tensor = loss
-            losses_reduced.append(loss_reduced)
-            input_tensors.append(input_tensor)
-            output_tensors.append(output_tensor)
-            timers('forward-compute').stop()
+        forward_step_with_communication(
+            forward_step_func, data_iterator, model,
+            input_tensors, output_tensors,
+            losses_reduced, timers)
 
     # Before running 1F1B, need to receive first forward tensor.
-    if (num_microbatches_in_minibatch - num_warmup_microbatches) > 0:
+    # If all microbatches are run in warmup / cooldown phase, then no need to
+    # receive this tensor here.
+    if num_microbatches_in_minibatch_remaining > 0:
         if mpu.is_pipeline_first_stage():
             input_tensor = None
         else:
@@ -464,8 +473,8 @@ def train_step(forward_step_func, data_iterator,
             timers('forward-recv').stop()
 
     # Run 1F1B.
-    for i in range(num_microbatches_in_minibatch - num_warmup_microbatches):
-        last_iteration = (i == (num_microbatches_in_minibatch - num_warmup_microbatches - 1))
+    for i in range(num_microbatches_in_minibatch_remaining):
+        last_iteration = (i == (num_microbatches_in_minibatch_remaining - 1))
         input_tensor = \
             forward_and_backward_steps_with_communication(forward_step_func, data_iterator, model,
                                                           optimizer,
@@ -475,16 +484,30 @@ def train_step(forward_step_func, data_iterator,
 
     # Run cooldown backward passes.
     for i in range(num_warmup_microbatches):
-        if args.pipeline_model_parallel_size > 1:
-            backward_step_with_communication(
-                optimizer, model, input_tensors, output_tensors, timers)
-        else:
-            timers('backward-compute').start()
-            input_tensor = input_tensors.pop(0)
-            output_tensor = output_tensors.pop(0)
-            output_tensor_grad = None
-            backward_step(optimizer, model, input_tensor, output_tensor, output_tensor_grad)
-            timers('backward-compute').stop()
+        backward_step_with_communication(
+            optimizer, model, input_tensors, output_tensors, timers)
+
+    return losses_reduced
+
+
+def train_step(forward_step_func, data_iterator,
+               model, optimizer, lr_scheduler):
+    """Single training step."""
+    args = get_args()
+    timers = get_timers()
+
+    # Set grad to zero.
+    if args.fp16:
+        optimizer.zero_grad(set_grads_to_None=True)
+    else:
+        optimizer.zero_grad()
+
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        losses_reduced = forward_backward_pipelining(
+            forward_step_func, data_iterator, model, optimizer, timers)
+    else:
+        losses_reduced = forward_backward_no_pipelining(
+            forward_step_func, data_iterator, model, optimizer, timers)
 
     # All-reduce if needed.
     if args.DDP_impl == 'local':
@@ -499,7 +522,7 @@ def train_step(forward_step_func, data_iterator,
     # (BERT and GPT-2).
     timers('backward-embedding-all-reduce').start()
     if (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()) and \
-            args.pipeline_model_parallel_size > 1:
+            mpu.get_pipeline_model_parallel_world_size() > 1:
         unwrapped_model = model
         while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
             unwrapped_model = unwrapped_model.module
...
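
The warmup expression above is cut off at a hunk boundary (`mpu.get_pipeline_model_parallel_world_size() - ...`). Assuming the standard 1F1B formula of pipeline world size minus stage rank minus one, the arithmetic works out as in this sketch:

```python
# Sketch of the 1F1B warmup arithmetic. Subtracting the stage's rank is
# an assumption: the diff truncates the expression at
# "get_pipeline_model_parallel_world_size() -".
def microbatch_schedule(world_size, rank, num_microbatches):
    num_warmup = min(world_size - rank - 1, num_microbatches)
    num_remaining = num_microbatches - num_warmup
    return num_warmup, num_remaining

# With 4 pipeline stages and 8 microbatches per minibatch:
#   stage 0: 3 warmup forwards, 5 1F1B iterations, 3 cooldown backwards
#   stage 3: 0 warmup forwards, 8 1F1B iterations, 0 cooldown backwards
# If every microbatch fits in the warmup/cooldown phases, num_remaining
# is 0 and the initial forward-recv before the 1F1B loop is skipped,
# as the new comment in the diff notes.
```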
@@ -34,7 +34,7 @@ def model_provider():
     print_rank_0('building BERT model ...')
 
     args = get_args()
-    if args.pipeline_model_parallel_size > 1:
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
         # Determine model based on position of stage in pipeline.
         if mpu.is_pipeline_first_stage():
             model = BertModelFirstStage(
...
@@ -33,7 +33,7 @@ def model_provider():
     print_rank_0('building GPT2 model ...')
 
     args = get_args()
-    if args.pipeline_model_parallel_size > 1:
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
         # Determine model based on position of stage in pipeline.
         if mpu.is_pipeline_first_stage():
             model = GPT2ModelFirstStage(num_tokentypes=0)
...
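
Both model providers branch the same way on pipeline position; a sketch of the full dispatch in the GPT-2 provider (only the first-stage branch is visible in the hunk, so the other class names here are assumptions extrapolated from the `*FirstStage` naming convention):

```python
# Sketch of the stage dispatch in the GPT-2 model_provider(). Only
# GPT2ModelFirstStage appears in the hunk; the last/intermediate-stage
# class names below are assumptions based on the naming convention.
if mpu.get_pipeline_model_parallel_world_size() > 1:
    # Each pipeline stage instantiates only its own slice of the model.
    if mpu.is_pipeline_first_stage():
        model = GPT2ModelFirstStage(num_tokentypes=0)
    elif mpu.is_pipeline_last_stage():
        model = GPT2ModelLastStage(num_tokentypes=0, parallel_output=True)
    else:
        model = GPT2ModelIntermediateStage(num_tokentypes=0)
else:
    # No pipelining: one rank holds the whole model.
    model = GPT2Model(num_tokentypes=0, parallel_output=True)
```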