"doc/vscode:/vscode.git/clone" did not exist on "d04b570fb576cac75682ab5246358b04565e9a71"
Commit 626645c0 authored by Deepak Narayanan

Comments in megatron/schedules.py and address a few more comments

parent cc691cbf
megatron/schedules.py

@@ -23,7 +23,12 @@ from megatron import p2p_communication
 def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
-    """Forward step."""
+    """Forward step for passed-in model.
+
+    If first stage, input tensor is obtained from data_iterator, otherwise
+    passed-in input_tensor is used.
+
+    Returns output tensor."""
     timers = get_timers()
     timers('forward-compute').start()
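The expanded docstring pins down a contract that is easy to miss: only the first pipeline stage draws a microbatch from data_iterator; every later stage runs on the activation handed to it from upstream. A toy sketch of that dispatch (illustrative names only, not Megatron's API):

```python
# Toy sketch of the forward_step contract: the first stage pulls a
# microbatch from the data iterator, later stages consume the activation
# received from the previous stage.
import torch

def toy_forward_step_func(data_iterator, model, input_tensor):
    if input_tensor is None:          # first stage
        x = next(data_iterator)
    else:                             # any later stage
        x = input_tensor
    return model(x)

stage = torch.nn.Linear(4, 4)
batches = iter([torch.randn(2, 4) for _ in range(3)])

out0 = toy_forward_step_func(batches, stage, None)          # stage 0
out1 = toy_forward_step_func(None, stage, out0.detach())    # stage 1
print(out1.shape)  # torch.Size([2, 4])
```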
@@ -38,7 +43,13 @@ def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
 def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
-    """Backward step."""
+    """Backward step through passed-in output tensor.
+
+    If last stage, output_tensor_grad is None, otherwise gradient of loss
+    with respect to stage's output tensor.
+
+    Returns gradient of loss with respect to input tensor (None if first
+    stage)."""
     args = get_args()
     timers = get_timers()
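The backward contract in one self-contained sketch (illustrative, not Megatron's implementation): backpropagate through this stage's output tensor and return the gradient of the loss with respect to this stage's input.

```python
import torch

stage = torch.nn.Linear(4, 4)
input_tensor = torch.randn(2, 4, requires_grad=True)
output_tensor = stage(input_tensor)

# On a non-last stage, output_tensor_grad arrives from the next stage;
# the last stage would call backward on the loss itself instead.
output_tensor_grad = torch.ones_like(output_tensor)
torch.autograd.backward(output_tensor, grad_tensors=output_tensor_grad)

# This is what gets sent upstream; it would be None on the first stage.
input_tensor_grad = input_tensor.grad
print(input_tensor_grad.shape)  # torch.Size([2, 4])
```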
@@ -65,7 +76,10 @@ def backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad):
 def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
                                    optimizer, timers, forward_only):
-    """Run forward and backward passes without inter-stage communication."""
+    """Run forward and backward passes with no pipeline parallelism
+    (no inter-stage communication).
+
+    Returns dictionary with losses."""
     assert len(model) == 1
     model = model[0]
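Without pipelining the schedule degenerates to plain gradient accumulation. A sketch on a toy model (assumption: gradients accumulate across microbatches before one optimizer step; the real function also handles timers and loss reduction):

```python
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
microbatches = [torch.randn(2, 4) for _ in range(4)]

optimizer.zero_grad()
losses_reduced = []
for x in microbatches:        # forward + backward per microbatch,
    loss = model(x).mean()    # no inter-stage communication anywhere
    loss.backward()           # grads accumulate in param.grad
    losses_reduced.append(loss.item())
optimizer.step()
print(losses_reduced)
```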
@@ -83,7 +97,10 @@ def forward_backward_no_pipelining(forward_step_func, data_iterator, model,
 def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
                                                   optimizer, timers, forward_only):
-    """Run interleaved 1F1B schedule."""
+    """Run interleaved 1F1B schedule (model split into model chunks), with
+    communication between pipeline stages as needed.
+
+    Returns dictionary with losses if the last stage, empty dict otherwise."""
     input_tensors = [[] for _ in range(len(model))]
     output_tensors = [[] for _ in range(len(model))]
     losses_reduced = []
@@ -100,18 +117,27 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
     if forward_only:
         num_warmup_microbatches = num_microbatches
     else:
+        # Run all forward passes and then all backward passes if number of
+        # microbatches is just the number of pipeline stages.
+        # Otherwise, perform (num_model_chunks-1)*pipeline_parallel_size on
+        # all workers, followed by more microbatches after depending on
+        # stage ID (more forward passes for earlier stages, later stages can
+        # immediately start with 1F1B).
         if get_num_microbatches() == pipeline_parallel_size:
             num_warmup_microbatches = num_microbatches
             all_warmup_microbatches = True
         else:
             num_warmup_microbatches = \
                 (pipeline_parallel_size - pipeline_parallel_rank - 1) * 2
-            num_warmup_microbatches += (num_model_chunks - 1) * pipeline_parallel_size
-            num_warmup_microbatches = min(num_warmup_microbatches, num_microbatches)
+            num_warmup_microbatches += (
+                num_model_chunks - 1) * pipeline_parallel_size
+            num_warmup_microbatches = min(num_warmup_microbatches,
+                                          num_microbatches)
     num_microbatches_remaining = \
         num_microbatches - num_warmup_microbatches

     def get_model_chunk_id(k, forward):
+        """Helper method to get the model chunk ID given the iteration number."""
         k_in_group = k % (pipeline_parallel_size * num_model_chunks)
         i = k_in_group // pipeline_parallel_size
         if not forward:
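The warmup arithmetic is easier to follow with concrete numbers, and the chunk-id mapping is easier to see run. In the sketch below, pipeline_parallel_size=4 and num_model_chunks=2 are chosen purely for illustration, and the backward-direction reversal (which falls in the gap between hunks) is an assumption consistent with Megatron's interleaved schedule:

```python
# Worked example of the warmup count and chunk-id mapping above.
pipeline_parallel_size = 4
num_model_chunks = 2
num_microbatches = 16   # > pipeline_parallel_size, so we take this branch

for rank in range(pipeline_parallel_size):
    warmup = (pipeline_parallel_size - rank - 1) * 2
    warmup += (num_model_chunks - 1) * pipeline_parallel_size
    warmup = min(warmup, num_microbatches)
    print(rank, warmup)   # ranks 0..3 -> 10, 8, 6, 4 warmup microbatches

def get_model_chunk_id(k, forward):
    # Iterations cycle through chunks in groups of
    # pipeline_parallel_size * num_model_chunks.
    k_in_group = k % (pipeline_parallel_size * num_model_chunks)
    i = k_in_group // pipeline_parallel_size
    if not forward:
        i = num_model_chunks - i - 1   # backward visits chunks in reverse
    return i

print([get_model_chunk_id(k, True) for k in range(8)])
# [0, 0, 0, 0, 1, 1, 1, 1]: four microbatches per chunk, then wrap around
```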
@@ -119,14 +145,19 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
         return i

     def forward_step_helper(k):
+        """Helper method to run forward step with model split into chunks
+        (run set_virtual_pipeline_model_parallel_rank() before calling
+        forward_step())."""
         model_chunk_id = get_model_chunk_id(k, forward=True)
         mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)

         if mpu.is_pipeline_first_stage():
-            if len(input_tensors[model_chunk_id]) == len(output_tensors[model_chunk_id]):
+            if len(input_tensors[model_chunk_id]) == \
+                    len(output_tensors[model_chunk_id]):
                 input_tensors[model_chunk_id].append(None)
         input_tensor = input_tensors[model_chunk_id][-1]
-        output_tensor = forward_step(forward_step_func, data_iterator[model_chunk_id],
+        output_tensor = forward_step(forward_step_func,
+                                     data_iterator[model_chunk_id],
                                      model[model_chunk_id],
                                      input_tensor, losses_reduced)
         output_tensors[model_chunk_id].append(output_tensor)
@@ -134,6 +165,9 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
         return output_tensor

     def backward_step_helper(k):
+        """Helper method to run backward step with model split into chunks
+        (run set_virtual_pipeline_model_parallel_rank() before calling
+        backward_step())."""
         model_chunk_id = get_model_chunk_id(k, forward=False)
         mpu.set_virtual_pipeline_model_parallel_rank(model_chunk_id)
@@ -144,15 +178,21 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
         output_tensor = output_tensors[model_chunk_id].pop(0)
         output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
         input_tensor_grad = \
-            backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
+            backward_step(optimizer,
+                          input_tensor,
+                          output_tensor,
+                          output_tensor_grad)

         return input_tensor_grad

     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
-    input_tensors[0].append(p2p_communication.recv_forward(timers, use_ring_exchange=True))
+    input_tensors[0].append(
+        p2p_communication.recv_forward(timers, use_ring_exchange=True))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)
+
+        # Determine if tensor should be received from previous stage.
         next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
         recv_prev = True
         if mpu.is_pipeline_first_stage(ignore_virtual=True):
@@ -160,8 +200,13 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
                 recv_prev = False
         if k == (num_microbatches - 1):
             recv_prev = False
+
+        # Don't send tensor downstream if on last stage.
         if mpu.is_pipeline_last_stage():
             output_tensor = None
+
+        # Send and receive tensors as appropriate (send tensors computed
+        # in this iteration; receive tensors for next iteration).
         if k == (num_warmup_microbatches - 1) and not forward_only and \
                 not all_warmup_microbatches:
             input_tensor_grad = None
@@ -176,7 +221,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
             output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
         else:
             input_tensor = \
-                p2p_communication.send_forward_recv_forward(output_tensor, recv_prev, timers)
+                p2p_communication.send_forward_recv_forward(
+                    output_tensor, recv_prev, timers)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)

     # Run 1F1B in steady state.
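The steady-state loop itself is largely elided in this diff; it pairs one forward with one backward per iteration. The index bookkeeping below is a sketch consistent with the forward_k/backward_k names visible in the next hunks:

```python
# Sketch of steady-state index pairing: iteration k runs forward microbatch
# k + num_warmup_microbatches and backward microbatch k.
num_microbatches = 16
num_warmup_microbatches = 10
num_microbatches_remaining = num_microbatches - num_warmup_microbatches

for k in range(num_microbatches_remaining):
    forward_k = k + num_warmup_microbatches   # newest microbatch, forward
    backward_k = k                            # oldest in flight, backward
    print(forward_k, backward_k)              # (10,0) (11,1) ... (15,5)
# Microbatches 6..15 still owe a backward pass afterwards: that is the
# cooldown phase which flushes the pipeline.
```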
@@ -215,7 +261,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
                 recv_prev = False
             next_forward_model_chunk_id += 1
         else:
-            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1, forward=True)
+            next_forward_model_chunk_id = get_model_chunk_id(forward_k + 1,
+                                                             forward=True)

         recv_next = True
         if mpu.is_pipeline_last_stage(ignore_virtual=True):
@@ -226,10 +273,11 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
                 recv_next = False
             next_backward_model_chunk_id -= 1
         else:
-            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1, forward=False)
+            next_backward_model_chunk_id = get_model_chunk_id(backward_k + 1,
+                                                              forward=False)

-        # If last iteration, don't receive; we already received one extra before the
-        # start of the for loop.
+        # If last iteration, don't receive; we already received one extra
+        # before the start of the for loop.
         if k == (num_microbatches_remaining - 1):
             recv_prev = False
@@ -240,13 +288,15 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
             recv_prev=recv_prev, recv_next=recv_next,
             timers=timers)

-        # Put input_tensor and output_tensor_grad in data structures in the right location.
+        # Put input_tensor and output_tensor_grad in data structures in the
+        # right location.
         if recv_prev:
             input_tensors[next_forward_model_chunk_id].append(input_tensor)
         if recv_next:
-            output_tensor_grads[next_backward_model_chunk_id].append(output_tensor_grad)
+            output_tensor_grads[next_backward_model_chunk_id].append(
+                output_tensor_grad)

-    # Run cooldown backward passes.
+    # Run cooldown backward passes (flush out pipeline).
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
@@ -269,7 +319,10 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterator, model,
 def forward_backward_pipelining(forward_step_func, data_iterator, model,
                                 optimizer, timers, forward_only):
-    """Run 1F1B schedule, with communication and warmup + cooldown microbatches as needed."""
+    """Run non-interleaved 1F1B schedule, with communication between pipeline
+    stages.
+
+    Returns dictionary with losses if the last stage, empty dict otherwise."""
     timers = get_timers()

     assert len(model) == 1
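For contrast with the interleaved variant, here is a self-contained sketch of the phase structure of plain 1F1B. The warmup count pipeline_size - rank - 1 is an assumption (the pipeline depth below this stage); the real function interleaves p2p sends and receives with every step:

```python
# Phase structure of non-interleaved 1F1B on one pipeline rank.
def one_f_one_b_plan(pipeline_size, rank, num_microbatches):
    # Warmup: enough forwards to keep downstream stages busy.
    num_warmup = min(pipeline_size - rank - 1, num_microbatches)
    remaining = num_microbatches - num_warmup
    plan = [('F', i) for i in range(num_warmup)]
    for i in range(remaining):            # steady state: 1F then 1B
        plan.append(('F', num_warmup + i))
        plan.append(('B', i))
    plan += [('B', remaining + i) for i in range(num_warmup)]  # cooldown
    return plan

print(one_f_one_b_plan(pipeline_size=4, rank=0, num_microbatches=6))
# [('F', 0), ('F', 1), ('F', 2), ('F', 3), ('B', 0), ('F', 4), ('B', 1),
#  ('F', 5), ('B', 2), ('B', 3), ('B', 4), ('B', 5)]
```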
@@ -327,7 +380,8 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             p2p_communication.send_forward(output_tensor, timers)
         else:
             output_tensor_grad = \
-                p2p_communication.send_forward_recv_backward(output_tensor, timers)
+                p2p_communication.send_forward_recv_backward(output_tensor,
+                                                             timers)

         # Add input_tensor and output_tensor to end of list, then pop from the
         # start of the list for backward pass.
@@ -349,7 +403,8 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
                 p2p_communication.send_backward(input_tensor_grad, timers)
         else:
             input_tensor = \
-                p2p_communication.send_backward_recv_forward(
+                p2p_communication.send_backward_recv_forward(
                     input_tensor_grad, timers)

     # Run cooldown backward passes.
     if not forward_only:
megatron/training.py
@@ -120,13 +120,13 @@ def pretrain(train_valid_test_dataset_provider,
     # Data stuff.
     timers('train/valid/test-data-iterators-setup').start()
     if args.virtual_pipeline_model_parallel_size is not None:
-        data_iterators = [
+        all_data_iterators = [
             build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
             for _ in range(len(model))
         ]
-        train_data_iterator = [x[0] for x in data_iterators]
-        valid_data_iterator = [x[1] for x in data_iterators]
-        test_data_iterator = [x[2] for x in data_iterators]
+        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
+        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
+        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
     else:
         train_data_iterator, valid_data_iterator, test_data_iterator \
             = build_train_valid_test_data_iterators(
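The rename makes the data layout explicit: one (train, valid, test) triple is built per model chunk, then transposed so that each of the three resulting lists is indexed by model chunk. A toy illustration of the same reshaping, with strings standing in for the real iterators:

```python
num_model_chunks = 2
all_data_iterators = [
    (f'train{i}', f'valid{i}', f'test{i}') for i in range(num_model_chunks)
]
train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
print(train_data_iterator)  # ['train0', 'train1'], indexed by model chunk
```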
@@ -311,17 +311,18 @@ def setup_model_and_optimizer(model_provider_func):
     # We only support local DDP with multiple micro-batches.
     if get_num_microbatches() > 1:
         assert args.DDP_impl == 'local'
+    if len(model) == 1:
+        assert args.DDP_impl == 'local'
+    if mpu.get_pipeline_model_parallel_world_size() > 1:
+        assert args.DDP_impl == 'local'

     # get model without FP16 and/or TorchDDP wrappers
+    model = unwrap_model(model)
     for module in model:
-        unwrapped_module = module
-        while hasattr(unwrapped_module, 'module'):
-            unwrapped_module = unwrapped_module.module
-        if args.iteration == 0 and hasattr(unwrapped_module,
+        if args.iteration == 0 and hasattr(module,
                                            'init_state_dict_from_bert'):
             print("Initializing ICT from pretrained BERT model", flush=True)
-            unwrapped_module.init_state_dict_from_bert()
+            module.init_state_dict_from_bert()

     return model, optimizer, lr_scheduler
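The new unwrap_model() call replaces the inlined while-loop. A sketch of what the helper presumably does, inferred from the removed lines (the real helper may additionally check the wrapper types it peels off):

```python
# Assumed behavior of unwrap_model(), reconstructed from the removed loop:
# peel off wrappers (DDP, FP16) that expose the inner model as .module.
import torch

def unwrap_model(model):
    unwrapped = []
    for module in model:       # model is a list of pipeline model chunks
        while hasattr(module, 'module'):
            module = module.module
        unwrapped.append(module)
    return unwrapped

class Wrapper(torch.nn.Module):    # hypothetical stand-in for DDP/FP16
    def __init__(self, module):
        super().__init__()
        self.module = module

inner = torch.nn.Linear(4, 4)
print(unwrap_model([Wrapper(Wrapper(inner))])[0] is inner)  # True
```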
@@ -364,7 +365,8 @@ def train_step(forward_step_func, data_iterator,
     # This should only run for models that support pipelined model parallelism
     # (BERT and GPT-2).
     timers('backward-embedding-all-reduce').start()
-    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
+    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
+        mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
             mpu.get_pipeline_model_parallel_world_size() > 1:
         if mpu.is_pipeline_first_stage(ignore_virtual=True):
             unwrapped_model = model[0]
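Background for this hunk: when input and output embeddings are tied across the pipeline, the first and last stages each hold a copy of the word-embedding weights, and their gradients must be summed after every step. A hedged, runnable sketch with plain torch.distributed; Megatron routes this through its own embedding process group rather than the default group used here:

```python
# Runnable single-process sketch of the embedding-gradient all-reduce.
import os
import torch
import torch.distributed as dist

def allreduce_embedding_grads(weight, group=None):
    # Sum this stage's embedding gradient with the copy held by the
    # opposite end of the pipeline.
    dist.all_reduce(weight.grad, op=dist.ReduceOp.SUM, group=group)

if __name__ == '__main__':
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)
    emb = torch.nn.Embedding(10, 4)
    emb.weight.grad = torch.ones_like(emb.weight)
    allreduce_embedding_grads(emb.weight)   # no-op at world_size == 1
    print(emb.weight.grad.sum())            # tensor(40.)
```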