Commit cc691cbf authored by Deepak Narayanan

Move unwrap to megatron/utils.py and clean up imports in megatron/schedules.py

parent e3e5ea89
--- a/megatron/checkpointing.py
+++ b/megatron/checkpointing.py
@@ -21,12 +21,12 @@ import sys
 import numpy as np

 import torch
 from torch.nn.parallel import DistributedDataParallel as torchDDP

 from megatron import (get_args,
                       mpu,
                       print_rank_0,
-                      update_num_microbatches)
+                      update_num_microbatches,
+                      utils)

 _CHECKPOINT_VERSION = None
@@ -111,12 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
-    unwrapped_model = []
-    for model_module in model:
-        if isinstance(model_module, torchDDP):
-            model_module = model_module.module
-        unwrapped_model.append(model_module)
-    model = unwrapped_model
+    model = utils.unwrap_model(model)

     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))
@@ -220,12 +215,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     args = get_args()
     load_dir = getattr(args, load_arg)

-    unwrapped_model = []
-    for model_module in model:
-        if isinstance(model_module, torchDDP):
-            model_module = model_module.module
-        unwrapped_model.append(model_module)
-    model = unwrapped_model
+    model = utils.unwrap_model(model)

     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)
@@ -389,8 +379,7 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f
     args = get_args()

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

     load_path = args.load if from_realm_chkpt else args.ict_load
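The three deleted blocks above collapse into a single helper call. A minimal runnable sketch of the new behavior, with a stub standing in for torchDDP (the real class needs an initialized process group; only the .module attribute matters here):

class DDPStub:
    def __init__(self, module):
        self.module = module

def unwrap_model(model, module_instances=(DDPStub,)):
    # Mirrors the helper this commit adds to megatron/utils.py.
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    return unwrapped_model if return_list else unwrapped_model[0]

core = object()
assert unwrap_model([DDPStub(core)]) == [core]  # save/load_checkpoint: list in, list out
assert unwrap_model(DDPStub(core)) is core      # load_ict_checkpoint: single module in and out

Note the helper loops with while where the deleted code tested with if, so it also strips nested wrappers; for checkpointing's singly-wrapped modules the two are equivalent.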
--- a/megatron/schedules.py
+++ b/megatron/schedules.py
@@ -16,14 +16,10 @@
 import torch

 from megatron import get_args
+from megatron import get_num_microbatches
 from megatron import get_timers
 from megatron import mpu
-from megatron import get_num_microbatches
-from megatron.p2p_communication import recv_forward, recv_backward
-from megatron.p2p_communication import send_forward, send_backward
-from megatron.p2p_communication import send_forward_recv_backward, send_backward_recv_forward
-from megatron.p2p_communication import send_forward_recv_forward, send_backward_recv_backward
-from megatron.p2p_communication import send_forward_backward_recv_forward_backward
+from megatron import p2p_communication


 def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):
@@ -154,7 +150,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
-    input_tensors[0].append(recv_forward(timers, use_ring_exchange=True))
+    input_tensors[0].append(p2p_communication.recv_forward(timers, use_ring_exchange=True))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)

         next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)
@@ -173,13 +169,14 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
             if mpu.is_pipeline_last_stage(ignore_virtual=True):
                 recv_next = False
             input_tensor, output_tensor_grad = \
-                send_forward_backward_recv_forward_backward(
+                p2p_communication.send_forward_backward_recv_forward_backward(
                     output_tensor, input_tensor_grad,
                     recv_prev=recv_prev, recv_next=recv_next,
                     timers=timers)
             output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
         else:
-            input_tensor = send_forward_recv_forward(output_tensor, recv_prev, timers)
+            input_tensor = \
+                p2p_communication.send_forward_recv_forward(output_tensor, recv_prev, timers)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)

     # Run 1F1B in steady state.
@@ -238,7 +235,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         # Communicate tensors.
         input_tensor, output_tensor_grad = \
-            send_forward_backward_recv_forward_backward(
+            p2p_communication.send_forward_backward_recv_forward_backward(
                 output_tensor, input_tensor_grad,
                 recv_prev=recv_prev, recv_next=recv_next,
                 timers=timers)

@@ -253,7 +250,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
-                recv_backward(timers, use_ring_exchange=True))
+                p2p_communication.recv_backward(timers, use_ring_exchange=True))
         for k in range(num_microbatches_remaining, num_microbatches):
             input_tensor_grad = backward_step_helper(k)
             next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)
@@ -264,7 +261,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
             if k == (num_microbatches - 1):
                 recv_next = False
             output_tensor_grads[next_backward_model_chunk_id].append(
-                send_backward_recv_backward(input_tensor_grad, recv_next, timers))
+                p2p_communication.send_backward_recv_backward(
+                    input_tensor_grad, recv_next, timers))

     return losses_reduced
@@ -294,7 +292,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
-        input_tensor = recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         # Barrier before first receive to measure forward stall.
@@ -302,7 +300,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
         timers('forward-pipeline-stall').start()
         torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
         timers('forward-pipeline-stall').stop()
-        send_forward(output_tensor, timers)
+        p2p_communication.send_forward(output_tensor, timers)

         input_tensors.append(input_tensor)
         output_tensors.append(output_tensor)
@@ -317,7 +315,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
     # If all microbatches are run in warmup / cooldown phase, then no need to
     # receive this tensor here.
     if num_microbatches_remaining > 0:
-        input_tensor = recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers)

     # Run 1F1B in steady state.
     for i in range(num_microbatches_remaining):
@@ -326,9 +324,10 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         if forward_only:
-            send_forward(output_tensor, timers)
+            p2p_communication.send_forward(output_tensor, timers)
         else:
-            output_tensor_grad = send_forward_recv_backward(output_tensor, timers)
+            output_tensor_grad = \
+                p2p_communication.send_forward_recv_backward(output_tensor, timers)

         # Add input_tensor and output_tensor to end of list, then pop from the
         # start of the list for backward pass.
@@ -337,7 +336,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
         if forward_only:
             if not last_iteration:
-                input_tensor = recv_forward(timers)
+                input_tensor = p2p_communication.recv_forward(timers)

         else:
             input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)
@@ -347,9 +346,10 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             if last_iteration:
                 input_tensor = None
-                send_backward(input_tensor_grad, timers)
+                p2p_communication.send_backward(input_tensor_grad, timers)
             else:
-                input_tensor = send_backward_recv_forward(input_tensor_grad, timers)
+                input_tensor = \
+                    p2p_communication.send_backward_recv_forward(input_tensor_grad, timers)

     # Run cooldown backward passes.
     if not forward_only:
@@ -357,12 +357,12 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
         input_tensor = input_tensors.pop(0)
         output_tensor = output_tensors.pop(0)

-        output_tensor_grad = recv_backward(timers)
+        output_tensor_grad = p2p_communication.recv_backward(timers)

         input_tensor_grad = \
             backward_step(optimizer, input_tensor, output_tensor,
                           output_tensor_grad)

-        send_backward(input_tensor_grad, timers)
+        p2p_communication.send_backward(input_tensor_grad, timers)

     return losses_reduced
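The schedules.py edits are mechanical: one module import replaces five from-imports covering nine helpers, and each call site gains a p2p_communication. qualifier. The tradeoff, sketched with a stdlib module so the snippet runs anywhere:

# Before-style: bare names; the call site hides where join comes from.
from os.path import join, split, dirname, basename
p = join('tmp', 'ckpt')

# After-style: one import; every call names its module, just as
# p2p_communication.recv_forward(...) now does in schedules.py.
import os.path
p = os.path.join('tmp', 'ckpt')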
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -46,6 +46,7 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
 from megatron.schedules import forward_backward_no_pipelining
@@ -288,9 +289,8 @@ def setup_model_and_optimizer(model_provider_func):
     model = get_model(model_provider_func)

-    unwrapped_model = model
-    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-        unwrapped_model = unwrapped_model.module
+    unwrapped_model = unwrap_model(model,
+                                   (torchDDP, LocalDDP, FP16Module))

     optimizer = get_megatron_optimizer(unwrapped_model)
     lr_scheduler = get_learning_rate_scheduler(optimizer)
@@ -370,8 +370,8 @@ def train_step(forward_step_func, data_iterator,
             unwrapped_model = model[0]
         elif mpu.is_pipeline_last_stage(ignore_virtual=True):
             unwrapped_model = model[-1]
-        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-            unwrapped_model = unwrapped_model.module
+        unwrapped_model = unwrap_model(
+            unwrapped_model, (torchDDP, LocalDDP, FP16Module))

         if unwrapped_model.share_word_embeddings:
             word_embeddings_weight = unwrapped_model.word_embeddings_weight()
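In training.py the wrappers genuinely nest (torchDDP around FP16Module around the model), which is why unwrap_model strips with a while loop rather than a single if. A compact runnable sketch with stand-in classes for the real wrappers:

class DDPStub:
    def __init__(self, module): self.module = module

class FP16Stub:
    def __init__(self, module): self.module = module

def unwrap(module, kinds):
    # The core of unwrap_model for one module: keep stripping layers.
    while isinstance(module, kinds):
        module = module.module
    return module

net = object()
wrapped = DDPStub(FP16Stub(net))
assert unwrap(wrapped, (DDPStub, FP16Stub)) is net
# A single isinstance test would stop one layer short:
assert isinstance(wrapped.module, FP16Stub)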
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -18,6 +18,7 @@
 import sys

 import torch
+from torch.nn.parallel import DistributedDataParallel as torchDDP
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

@@ -26,11 +27,25 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
-from megatron.checkpointing import save_checkpoint
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


+def unwrap_model(model, module_instances=(torchDDP)):
+    return_list = True
+    if not isinstance(model, list):
+        model = [model]
+        return_list = False
+    unwrapped_model = []
+    for model_module in model:
+        while isinstance(model_module, module_instances):
+            model_module = model_module.module
+        unwrapped_model.append(model_module)
+    if not return_list:
+        return unwrapped_model[0]
+    return unwrapped_model
+
+
 def calc_params_l2_norm(model):
     """Calculate l2 norm of parameters """
     # Remove duplicate params.
@@ -106,6 +121,8 @@ def print_params_min_max_norm(optimizer, iteration):
 def check_adlr_autoresume_termination(iteration, model,
                                       optimizer, lr_scheduler):
     """Check for autoresume signal and exit if it is received."""
+    from megatron.checkpointing import save_checkpoint
+
     args = get_args()
     autoresume = get_adlr_autoresume()

     # Add barrier to ensure consistency.
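Two details in the utils.py hunk are easy to miss. First, the default module_instances=(torchDDP) is not actually a tuple, since (x) is just x; it still works because isinstance accepts a lone class as well as a tuple of classes, though (torchDDP,) would state the intent more explicitly. A runnable check:

# Parenthesizing a name does not make a tuple; isinstance takes either form.
assert (int) is int
assert isinstance(3, (int)) and isinstance(3, (int,))

Second, save_checkpoint is now imported inside check_adlr_autoresume_termination rather than at module level. This breaks an import cycle: after this commit megatron/checkpointing.py imports megatron.utils (see the first hunk above), so a module-level "from megatron.checkpointing import save_checkpoint" in utils.py would trigger a circular import at startup. The deferred import runs only when the function is called, once both modules are fully initialized.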