Commit cc691cbf authored by Deepak Narayanan

Move unwrap to megatron/utils.py and clean up imports in megatron/schedules.py

parent e3e5ea89
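
In short, the DistributedDataParallel-unwrapping loop that was repeated at each call site moves into a shared unwrap_model helper in megatron/utils.py, and megatron/schedules.py now imports the p2p_communication module instead of its individual functions. A minimal before/after sketch of one call site, paraphrasing the save_checkpoint hunk below (names as in the diff):

# Before: each call site unwrapped torchDDP wrappers inline.
unwrapped_model = []
for model_module in model:
    if isinstance(model_module, torchDDP):
        model_module = model_module.module
    unwrapped_model.append(model_module)
model = unwrapped_model

# After: the shared helper performs the same unwrapping.
model = utils.unwrap_model(model)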
megatron/checkpointing.py

@@ -21,12 +21,12 @@ import sys
 import numpy as np

 import torch
-from torch.nn.parallel import DistributedDataParallel as torchDDP

 from megatron import (get_args,
                       mpu,
                       print_rank_0,
-                      update_num_microbatches)
+                      update_num_microbatches,
+                      utils)

 _CHECKPOINT_VERSION = None

@@ -111,12 +111,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     args = get_args()

     # Only rank zero of the data parallel writes to the disk.
-    unwrapped_model = []
-    for model_module in model:
-        if isinstance(model_module, torchDDP):
-            model_module = model_module.module
-        unwrapped_model.append(model_module)
-    model = unwrapped_model
+    model = utils.unwrap_model(model)

     print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
         iteration, args.save))

@@ -220,12 +215,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     args = get_args()
     load_dir = getattr(args, load_arg)

-    unwrapped_model = []
-    for model_module in model:
-        if isinstance(model_module, torchDDP):
-            model_module = model_module.module
-        unwrapped_model.append(model_module)
-    model = unwrapped_model
+    model = utils.unwrap_model(model)

     # Read the tracker file and set the iteration.
     tracker_filename = get_checkpoint_tracker_filename(load_dir)

@@ -389,8 +379,7 @@ def load_ict_checkpoint(model, only_query_model=False, only_block_model=False, f
     args = get_args()

-    if isinstance(model, torchDDP):
-        model = model.module
+    model = utils.unwrap_model(model)

     load_path = args.load if from_realm_chkpt else args.ict_load
...
megatron/schedules.py

@@ -16,14 +16,10 @@
 import torch

 from megatron import get_args
+from megatron import get_num_microbatches
 from megatron import get_timers
 from megatron import mpu
-from megatron import get_num_microbatches
-from megatron.p2p_communication import recv_forward, recv_backward
-from megatron.p2p_communication import send_forward, send_backward
-from megatron.p2p_communication import send_forward_recv_backward, send_backward_recv_forward
-from megatron.p2p_communication import send_forward_recv_forward, send_backward_recv_backward
-from megatron.p2p_communication import send_forward_backward_recv_forward_backward
+from megatron import p2p_communication


 def forward_step(forward_step_func, data_iterator, model, input_tensor, losses_reduced):

@@ -154,7 +150,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     # Run warmup forward passes.
     mpu.set_virtual_pipeline_model_parallel_rank(0)
-    input_tensors[0].append(recv_forward(timers, use_ring_exchange=True))
+    input_tensors[0].append(p2p_communication.recv_forward(timers, use_ring_exchange=True))
     for k in range(num_warmup_microbatches):
         output_tensor = forward_step_helper(k)
         next_forward_model_chunk_id = get_model_chunk_id(k+1, forward=True)

@@ -173,13 +169,14 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
             if mpu.is_pipeline_last_stage(ignore_virtual=True):
                 recv_next = False
             input_tensor, output_tensor_grad = \
-                send_forward_backward_recv_forward_backward(
+                p2p_communication.send_forward_backward_recv_forward_backward(
                     output_tensor, input_tensor_grad,
                     recv_prev=recv_prev, recv_next=recv_next,
                     timers=timers)
             output_tensor_grads[num_model_chunks-1].append(output_tensor_grad)
         else:
-            input_tensor = send_forward_recv_forward(output_tensor, recv_prev, timers)
+            input_tensor = \
+                p2p_communication.send_forward_recv_forward(output_tensor, recv_prev, timers)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)

     # Run 1F1B in steady state.

@@ -238,7 +235,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
         # Communicate tensors.
         input_tensor, output_tensor_grad = \
-            send_forward_backward_recv_forward_backward(
+            p2p_communication.send_forward_backward_recv_forward_backward(
                 output_tensor, input_tensor_grad,
                 recv_prev=recv_prev, recv_next=recv_next,
                 timers=timers)

@@ -253,7 +250,7 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
     if not forward_only:
         if all_warmup_microbatches:
             output_tensor_grads[num_model_chunks-1].append(
-                recv_backward(timers, use_ring_exchange=True))
+                p2p_communication.recv_backward(timers, use_ring_exchange=True))
         for k in range(num_microbatches_remaining, num_microbatches):
             input_tensor_grad = backward_step_helper(k)
             next_backward_model_chunk_id = get_model_chunk_id(k+1, forward=False)

@@ -264,7 +261,8 @@ def forward_backward_pipelining_with_interleaving(forward_step_func, data_iterat
             if k == (num_microbatches - 1):
                 recv_next = False
             output_tensor_grads[next_backward_model_chunk_id].append(
-                send_backward_recv_backward(input_tensor_grad, recv_next, timers))
+                p2p_communication.send_backward_recv_backward(
+                    input_tensor_grad, recv_next, timers))

     return losses_reduced

@@ -294,7 +292,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
     # Run warmup forward passes.
     for i in range(num_warmup_microbatches):
-        input_tensor = recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers)
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         # Barrier before first receive to measure forward stall.

@@ -302,7 +300,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
            timers('forward-pipeline-stall').start()
            torch.distributed.barrier(group=mpu.get_pipeline_model_parallel_group())
            timers('forward-pipeline-stall').stop()
-        send_forward(output_tensor, timers)
+        p2p_communication.send_forward(output_tensor, timers)

        input_tensors.append(input_tensor)
        output_tensors.append(output_tensor)

@@ -317,7 +315,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
     # If all microbatches are run in warmup / cooldown phase, then no need to
     # receive this tensor here.
     if num_microbatches_remaining > 0:
-        input_tensor = recv_forward(timers)
+        input_tensor = p2p_communication.recv_forward(timers)

     # Run 1F1B in steady state.
     for i in range(num_microbatches_remaining):

@@ -326,9 +324,10 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
         output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
         if forward_only:
-            send_forward(output_tensor, timers)
+            p2p_communication.send_forward(output_tensor, timers)
         else:
-            output_tensor_grad = send_forward_recv_backward(output_tensor, timers)
+            output_tensor_grad = \
+                p2p_communication.send_forward_recv_backward(output_tensor, timers)

         # Add input_tensor and output_tensor to end of list, then pop from the
         # start of the list for backward pass.

@@ -337,7 +336,7 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
         if forward_only:
             if not last_iteration:
-                input_tensor = recv_forward(timers)
+                input_tensor = p2p_communication.recv_forward(timers)
         else:
             input_tensor, output_tensor = input_tensors.pop(0), output_tensors.pop(0)

@@ -347,9 +346,10 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             if last_iteration:
                 input_tensor = None
-                send_backward(input_tensor_grad, timers)
+                p2p_communication.send_backward(input_tensor_grad, timers)
             else:
-                input_tensor = send_backward_recv_forward(input_tensor_grad, timers)
+                input_tensor = \
+                    p2p_communication.send_backward_recv_forward(input_tensor_grad, timers)

     # Run cooldown backward passes.
     if not forward_only:

@@ -357,12 +357,12 @@ def forward_backward_pipelining(forward_step_func, data_iterator, model,
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)

-            output_tensor_grad = recv_backward(timers)
+            output_tensor_grad = p2p_communication.recv_backward(timers)

             input_tensor_grad = \
                 backward_step(optimizer, input_tensor, output_tensor,
                               output_tensor_grad)

-            send_backward(input_tensor_grad, timers)
+            p2p_communication.send_backward(input_tensor_grad, timers)

     return losses_reduced
megatron/training.py

@@ -46,6 +46,7 @@ from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
 from megatron.schedules import forward_backward_no_pipelining

@@ -288,9 +289,8 @@ def setup_model_and_optimizer(model_provider_func):
     model = get_model(model_provider_func)

-    unwrapped_model = model
-    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-        unwrapped_model = unwrapped_model.module
+    unwrapped_model = unwrap_model(model,
+                                   (torchDDP, LocalDDP, FP16Module))
     optimizer = get_megatron_optimizer(unwrapped_model)
     lr_scheduler = get_learning_rate_scheduler(optimizer)

@@ -370,8 +370,8 @@ def train_step(forward_step_func, data_iterator,
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
-        while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16Module)):
-            unwrapped_model = unwrapped_model.module
+        unwrapped_model = unwrap_model(
+            unwrapped_model, (torchDDP, LocalDDP, FP16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
...
megatron/utils.py

@@ -18,6 +18,7 @@
 import sys

 import torch
+from torch.nn.parallel import DistributedDataParallel as torchDDP
 from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

@@ -26,11 +27,25 @@ from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
-from megatron.checkpointing import save_checkpoint
 from megatron.model.module import param_is_not_shared
 from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


+def unwrap_model(model, module_instances=(torchDDP)):
+    return_list = True
+    if not isinstance(model, list):
+        model = [model]
+        return_list = False
+    unwrapped_model = []
+    for model_module in model:
+        while isinstance(model_module, module_instances):
+            model_module = model_module.module
+        unwrapped_model.append(model_module)
+    if not return_list:
+        return unwrapped_model[0]
+    return unwrapped_model
+
+
 def calc_params_l2_norm(model):
     """Calculate l2 norm of parameters """
     # Remove duplicate params.

@@ -106,6 +121,8 @@ def print_params_min_max_norm(optimizer, iteration):
 def check_adlr_autoresume_termination(iteration, model,
                                       optimizer, lr_scheduler):
     """Check for autoresume signal and exit if it is received."""
+    from megatron.checkpointing import save_checkpoint
+
     args = get_args()
     autoresume = get_adlr_autoresume()
     # Add barrier to ensure consistnecy.
...
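
As a usage sketch (not part of the commit, and assuming the megatron package and its dependencies such as apex are importable): unwrap_model accepts either a single module or a list of modules, peels off every wrapper listed in module_instances, and returns the result in the same shape. The toy single-process setup below exists only so a real torchDDP wrapper can be constructed.

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as torchDDP
from megatron.utils import unwrap_model

# Single-process gloo group, only so torchDDP can wrap a module in this example.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

layer = torch.nn.Linear(4, 4)
wrapped = torchDDP(layer)

# Single wrapped module in, bare module out (the default module_instances matches torchDDP).
assert unwrap_model(wrapped) is layer

# A list (e.g. one model chunk per virtual pipeline stage) comes back as a list, order preserved.
assert unwrap_model([wrapped])[0] is layer

# Extra wrapper types can be passed explicitly, as training.py does:
#   unwrap_model(model, (torchDDP, LocalDDP, FP16Module))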