Commit 626645c0 authored by Deepak Narayanan

Add comments in megatron/schedules.py and address a few more review comments

parent cc691cbf
@@ -23,7 +23,12 @@ from megatron import p2p_communication
def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
"""Forward step."""
"""Forward step for passed-in model.
If first stage, input tensor is obtained from data_iterator, otherwise
passed-in input_tensor is used.
Returns output tensor."""
timers = get_timers()
timers('forward-compute').start()
@@ -38,7 +43,13 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_r
def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
"""Backward step."""
"""Backward step through passed-in output tensor.
If last stage, output_tensor_grad is None, otherwise gradient of loss
with respect to stage's output tensor.
Returns gradient of loss with respect to input tensor (None if first
stage)."""
args = get_args()
timers = get_timers()
@@ -65,7 +76,10 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run forward and backward passes without inter-stage communication."""
"""Run forward and backward passes with no pipeline parallelism
(no inter-stage communication).
Returns dictionary with losses."""
assert len(model) == 1
model = model[0]
@@ -83,7 +97,10 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run interleaved 1F1B schedule."""
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
input_tensors = [[] for _ in range(len(model))]
output_tensors = [[] for _ in range(len(model))]
losses_reduced = []
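To make the interleaved layout concrete, here is a hedged, stand-alone sketch (the sizes and the layout dict below are illustrative, not taken from this commit) of how model chunks can map to pipeline ranks when each rank holds num_model_chunks chunks:

```python
# Illustrative only: possible chunk-to-rank layout for an interleaved
# (virtual) pipeline with 4 pipeline ranks and 2 model chunks per rank.
pipeline_parallel_size = 4
num_model_chunks = 2

layout = {rank: [rank + chunk * pipeline_parallel_size
                 for chunk in range(num_model_chunks)]
          for rank in range(pipeline_parallel_size)}

print(layout)
# {0: [0, 4], 1: [1, 5], 2: [2, 6], 3: [3, 7]}
# i.e. each rank owns two non-adjacent slices of the layer stack, so a
# microbatch traverses every rank twice per forward pass.
```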
@@ -100,18 +117,27 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
if forward_only:
num_warmup_microbatches = num_microbatches
else:
# Run all forward passes and then all backward passes if number of
# microbatches is just the number of pipeline stages.
# Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size warmup
# microbatches on all workers, followed by additional warmup microbatches
# whose count depends on the stage ID (more forward passes for earlier
# stages; later stages can immediately start with 1F1B).
if get_num_microbatches() == pipeline_parallel_size:
num_warmup_microbatches = num_microbatches
all_warmup_microbatches = True
else:
num_warmup_microbatches = \
(pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
num_warmup_microbatches += (
num_model_chunks - 1) * pipeline_parallel_size
num_warmup_microbatches = min(num_warmup_microbatches,
num_microbatches)
num_microbatches_remaining = \
num_microbatches - num_warmup_microbatches
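A quick worked example of the warmup count computed above (numbers are illustrative, not from this commit): earlier pipeline stages run more warmup forward passes, later stages start 1F1B sooner.

```python
# Illustrative numbers only; this mirrors the warmup formula above.
pipeline_parallel_size = 4
num_model_chunks = 2
num_microbatches = 16

for pipeline_parallel_rank in range(pipeline_parallel_size):
    num_warmup = (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
    num_warmup += (num_model_chunks - 1) * pipeline_parallel_size
    num_warmup = min(num_warmup, num_microbatches)
    print(pipeline_parallel_rank, num_warmup, num_microbatches - num_warmup)
# rank 0: 10 warmup / 6 remaining, rank 1: 8 / 8,
# rank 2: 6 / 10, rank 3: 4 / 12
```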
def get_model_chunk_id(k, forward):
"""Helper method to get the model chunk ID given the iteration number."""
k_in_group = k % (pipeline_parallel_size * num_model_chunks)
i = k_in_group // pipeline_parallel_size
if not forward:
@@ -119,14 +145,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
return i
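For the forward direction (the only branch fully visible in this hunk), the mapping above cycles through the model chunks in groups of pipeline_parallel_size microbatches; a hedged stand-alone sketch:

```python
# Illustrative only: forward-direction chunk IDs with
# pipeline_parallel_size = 4 and num_model_chunks = 2.
pipeline_parallel_size = 4
num_model_chunks = 2

def forward_chunk_id(k):
    k_in_group = k % (pipeline_parallel_size * num_model_chunks)
    return k_in_group // pipeline_parallel_size

print([forward_chunk_id(k) for k in range(12)])
# [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]
```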
def forward_step_helper(k):
"""Helper method to run forward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
forward_step())."""
model_chunk_id = get_model_chunk_id(k, forward=True)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
if mpu.is_pipeline_first_stage():
if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
if len(input_tensors[model_chunk_id]) == \
len(output_tensors[model_chunk_id]):
input_tensors[model_chunk_id].append(None)
input_tensor = input_tensors[model_chunk_id][-1]
output_tensor = forward_step(forward_step_func, data_iterator[model_chunk_id],
output_tensor = forward_step(forward_step_func,
data_iterator[model_chunk_id],
model[model_chunk_id],
input_tensor, losses_reduced)
output_tensors[model_chunk_id].append(output_tensor)
@@ -134,6 +165,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
return output_tensor
def backward_step_helper(k):
"""Helper method to run backward step with model split into chunks
(run set_virtual_pipeline_model_parallel_rank() before calling
backward_step())."""
model_chunk_id = get_model_chunk_id(k, forward=False)
mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
@@ -144,15 +178,21 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = \
backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
backward_step(optimizer,
input_tensor,
output_tensor,
output_tensor_grad)
return input_tensor_grad
# Run warmup forward passes.
mpu.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(timers, use_ring_exchange=True))
input_tensors[0].append(
p2p_communication.recv_forward(timers, use_ring_exchange=True))
for k in range(num_warmup_microbatches):
output_tensor = forward_step_helper(k)
# Determine if tensor should be received from previous stage.
next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
recv_prev = True
if mpu.is_pipeline_first_stage(ignore_virtual=True):
@@ -160,8 +200,13 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
recv_prev = False
if k == (num_microbatches - 1):
recv_prev = False
# Don't send tensor downstream if on last stage.
if mpu.is_pipeline_last_stage():
output_tensor = None
# Send and receive tensors as appropriate (send tensors computed
# in this iteration; receive tensors for next iteration).
if k == (num_warmup_microbatches - 1) and not forward_only and \
not all_warmup_microbatches:
input_tensor_grad = None
@@ -176,7 +221,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
else:
input_tensor = \
p2p_communication.send_forward_recv_forward(output_tensor, recv_prev, timers)
p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev, timers)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
# Run 1F1B in steady state.
@@ -215,7 +261,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
recv_prev = False
next_forward_model_chunk_id += 1
else:
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
forward=True)
recv_next = True
if mpu.is_pipeline_last_stage(ignore_virtual=True):
@@ -226,10 +273,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
recv_next = False
next_backward_model_chunk_id -= 1
else:
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
forward=False)
# If last iteration, don't receive; we already received one extra before the
# start of the for loop.
# If last iteration, don't receive; we already received one extra
# before the start of the for loop.
if k == (num_microbatches_remaining - 1):
recv_prev = False
@@ -240,13 +288,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
recv_prev=recv_prev, recv_next=recv_next,
timers=timers)
# Put input_tensor and output_tensor_grad in data structures in the right location.
# Put input_tensor and output_tensor_grad in data structures in the
# right location.
if recv_prev:
input_tensors[next_forward_model_chunk_id].append(input_tensor)
if recv_next:
output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
output_tensor_grads[next_backward_model_chunk_id].append(
output_tensor_grad)
# Run cooldown backward passes.
# Run cooldown backward passes (flush out pipeline).
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks-1].append(
@@ -269,7 +319,10 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
def forward_backward_pipelining(forward_step_func, data_iterator, model,
optimizer, timers, forward_only):
"""Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
"""Run non-interleaved 1F1B schedule, with communication between pipeline
stages.
Returns dictionary with losses if the last stage, empty dict otherwise."""
timers = get_timers()
assert len(model) == 1
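As a hedged sketch of the phase structure behind this schedule (the exact warmup expression is not shown in this hunk, so the (pipeline_parallel_size - rank - 1) count below is an assumption, not a quote from the code): each stage runs a few warmup forward passes, then alternating 1F1B steps, then the matching cooldown backward passes.

```python
# Assumed warmup count for non-interleaved 1F1B (not visible in this diff);
# illustrative numbers only.
pipeline_parallel_size = 4
num_microbatches = 8

for rank in range(pipeline_parallel_size):
    num_warmup = min(pipeline_parallel_size - rank - 1, num_microbatches)
    num_steady = num_microbatches - num_warmup   # 1F1B steps
    num_cooldown = num_warmup                    # leftover backward passes
    print(rank, num_warmup, num_steady, num_cooldown)
# rank 0: 3 warmup / 5 steady / 3 cooldown ... rank 3: 0 / 8 / 0
```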
@@ -327,7 +380,8 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
p2p_communication.send_forward(output_tensor, timers)
else:
output_tensor_grad = \
p2p_communication.send_forward_recv_backward(output_tensor, timers)
p2p_communication.send_forward_recv_backward(output_tensor,
timers)
# Add input_tensor and output_tensor to end of list, then pop from the
# start of the list for backward pass.
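The append-then-pop(0) bookkeeping described in the comment above is simply a FIFO over in-flight microbatches; a minimal stand-alone sketch (names are illustrative, not from this commit):

```python
# Minimal FIFO sketch: forward passes append activations at the end,
# backward passes consume the oldest pair from the front.
input_tensors, output_tensors = [], []

for i in range(3):                       # forward passes
    input_tensors.append(f"in_{i}")
    output_tensors.append(f"out_{i}")

input_tensor = input_tensors.pop(0)      # backward pass uses oldest pair
output_tensor = output_tensors.pop(0)
print(input_tensor, output_tensor)       # in_0 out_0
```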
@@ -349,7 +403,8 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
p2p_communication.send_backward(input_tensor_grad, timers)
else:
input_tensor = \
p2p_communication.send_backward_recv_forward(input_tensor_grad, timers)
p2p_communication.send_backward_recv_forward(
input_tensor_grad, timers)
# Run cooldown backward passes.
if not forward_only:
@@ -120,13 +120,13 @@ def pretrain(train_valid_test_dataset_provider,
# Data stuff.
timers('train/valid/test-data-iterators-setup').start()
if args.virtual_pipeline_model_parallel_size is not None:
data_iterators = [
all_data_iterators = [
build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
for _ in range(len(model))
]
train_data_iterator = [x[0] for x in data_iterators]
valid_data_iterator = [x[1] for x in data_iterators]
test_data_iterator = [x[2] for x in data_iterators]
train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
else:
train_data_iterator, valid_data_iterator, test_data_iterator \
= build_train_valid_test_data_iterators(
@@ -311,17 +311,18 @@ def setup_model_and_optimizer(model_provider_func):
# We only support local DDP with multiple micro-batches.
if get_num_microbatches() > 1:
assert args.DDP_impl == 'local'
if len(model) == 1:
assert args.DDP_impl == 'local'
if mpu.get_pipeline_model_parallel_world_size() > 1:
assert args.DDP_impl == 'local'
# get model without FP16 and/or TorchDDP wrappers
model = unwrap_model(model)
for module in model:
unwrapped_module = module
while hasattr(unwrapped_module, 'module'):
unwrapped_module = unwrapped_module.module
if args.iteration == 0 and hasattr(unwrapped_module,
if args.iteration == 0 and hasattr(module,
'init_state_dict_from_bert'):
print("Initializing ICT from pretrained BERT model", flush=True)
unwrapped_module.init_state_dict_from_bert()
module.init_state_dict_from_bert()
return model, optimizer, lr_scheduler
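A hedged sketch of the kind of unwrapping the removed while-loop performed; the real unwrap_model helper used above lives elsewhere in the repo and may differ, so treat this stand-in as illustrative only:

```python
# Illustrative stand-in for unwrapping nested wrappers (e.g. DDP / FP16
# modules that expose the wrapped model as .module); not the real helper.
class Wrapper:
    def __init__(self, module):
        self.module = module

class BareModel:
    pass

def unwrap(module):
    # Peel off nested wrappers until the object has no .module attribute.
    while hasattr(module, 'module'):
        module = module.module
    return module

wrapped = Wrapper(Wrapper(BareModel()))
print(type(unwrap(wrapped)).__name__)    # BareModel
```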
@@ -364,7 +365,8 @@ def train_step(forward_step_func, data_iterator,
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers('backward-embedding-all-reduce').start()
if (mpu.is_pipeline_first_stage(ignore_virtual=True) or mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
mpu.get_pipeline_model_parallel_world_size() > 1:
if mpu.is_pipeline_first_stage(ignore_virtual=True):
unwrapped_model = model[0]