Commit 86e7d624 authored by Mohammad's avatar Mohammad
Browse files

refactored checkpooints but not tested yet

parent 11220df8
# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Input/output checkpointing."""
import os
import random
import sys
import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron import mpu
from .global_vars import get_args
from .utils import print_rank_0
def check_checkpoint_args(checkpoint_args):
"""Ensure fixed arguments for a model are the same for the input
arguments and the one retreived frm checkpoint."""
args = get_args()
def _compare(arg_name):
checkpoint_value = getattr(checkpoint_args, arg_name)
args_value = getattr(args, arg_name)
error_message = '{} value from checkpoint ({}) is not equal to the ' \
'input argument value ({}).'.format(
arg_name, checkpoint_value, args_value)
assert checkpoint_value == args_value, error_message
_compare('num_layers')
_compare('hidden_size')
_compare('num_attention_heads')
_compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
_compare('tokenizer_type')
_compare('model_parallel_size')
def ensure_directory_exists(filename):
"""Build filename's path if it does not already exists."""
dirname = os.path.dirname(filename)
if not os.path.exists(dirname):
os.makedirs(dirname)
def get_checkpoint_name(checkpoints_path, iteration,
release=False, mp_rank=None):
"""A unified checkpoint name."""
if release:
directory = 'release'
else:
directory = 'iter_{:07d}'.format(iteration)
return os.path.join(checkpoints_path, directory,
'mp_rank_{:02d}'.format(
mpu.get_model_parallel_rank() if mp_rank is None \
else mp_rank),
'model_optim_rng.pt')
def get_checkpoint_tracker_filename(checkpoints_path):
"""Tracker file rescords the latest chckpoint during
training to restart from."""
return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
"""Save a model checkpoint."""
args = get_args()
# Only rank zero of the data parallel writes to the disk.
if isinstance(model, torchDDP):
model = model.module
if mpu.get_data_parallel_rank() == 0:
# Arguments, iteration, and model.
state_dict = {}
state_dict['args'] = args
state_dict['iteration'] = iteration
state_dict['model'] = model.state_dict_for_save_checkpoint()
# Optimizer stuff.
if not args.no_save_optim:
if optimizer is not None:
state_dict['optimizer'] = optimizer.state_dict()
if lr_scheduler is not None:
state_dict['lr_scheduler'] = lr_scheduler.state_dict()
# RNG states.
if not args.no_save_rng:
state_dict['random_rng_state'] = random.getstate()
state_dict['np_rng_state'] = np.random.get_state()
state_dict['torch_rng_state'] = torch.get_rng_state()
state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
state_dict['rng_tracker_states'] \
= mpu.get_cuda_rng_tracker().get_states()
# Save.
checkpoint_name = get_checkpoint_name(args.save, iteration)
print('global rank {} is saving checkpoint at iteration {:7d} to {}'.
format(torch.distributed.get_rank(), iteration, checkpoint_name))
ensure_directory_exists(checkpoint_name)
torch.save(state_dict, checkpoint_name)
print(' successfully saved {}'.format(checkpoint_name))
# Wait so everyone is done (necessary)
torch.distributed.barrier()
# And update the latest iteration
if torch.distributed.get_rank() == 0:
tracker_filename = get_checkpoint_tracker_filename(args.save)
with open(tracker_filename, 'w') as f:
f.write(str(iteration))
# Wait so everyone is done (not necessary)
torch.distributed.barrier()
def load_checkpoint(model, optimizer, lr_scheduler):
"""Load a model checkpoint and return the iteration."""
args = get_args()
if isinstance(model, torchDDP):
model = model.module
# Read the tracker file and set the iteration.
tracker_filename = get_checkpoint_tracker_filename(args.load)
# If no tracker file, return iretation zero.
if not os.path.isfile(tracker_filename):
print_rank_0('WARNING: could not find the metadata file {} '.format(
tracker_filename))
print_rank_0(' will not load any checkpoints and will start from '
'random')
return 0
# Otherwise, read the tracker file and either set the iteration or
# mark it as a release checkpoint.
iteration = 0
release = False
with open(tracker_filename, 'r') as f:
metastring = f.read().strip()
try:
iteration = int(metastring)
except ValueError:
release = metastring == 'release'
if not release:
print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
tracker_filename))
sys.exit()
assert iteration > 0 or release, 'error parsing metadata file {}'.format(
tracker_filename)
# Checkpoint.
checkpoint_name = get_checkpoint_name(args.load, iteration, release)
if mpu.get_data_parallel_rank() == 0:
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))
# Load the checkpoint.
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
except ModuleNotFoundError:
# For backward compatibility.
print_rank_0(' > deserializing using the old code structure ...')
sys.modules['fp16.loss_scaler'] = sys.modules[
'megatron.fp16.loss_scaler']
state_dict = torch.load(checkpoint_name, map_location='cpu')
sys.modules.pop('fp16.loss_scaler', None)
except:
print_rank_0('could not load the checkpoint')
sys.exit()
# Set iteration.
if args.finetune or release:
iteration = 0
else:
try:
iteration = state_dict['iteration']
except KeyError:
try: # Backward compatible with older checkpoints
iteration = state_dict['total_iters']
except KeyError:
print_rank_0('A metadata file exists but unable to load '
'iteration from checkpoint {}, exiting'.format(
checkpoint_name))
sys.exit()
# Check arguments.
if 'args' in state_dict:
checkpoint_args = state_dict['args']
check_checkpoint_args(checkpoint_args)
else:
print_rank_0('could not find arguments in the checkpoint ...')
# Model.
model.load_state_dict(state_dict['model'])
# Optimizer.
if not release and not args.finetune and not args.no_load_optim:
try:
if optimizer is not None:
optimizer.load_state_dict(state_dict['optimizer'])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-optim or --finetune to prevent '
'attempting to load the optimizer state, '
'exiting ...'.format(checkpoint_name))
sys.exit()
# rng states.
if not release and not args.finetune and not args.no_load_rng:
try:
random.setstate(state_dict['random_rng_state'])
np.random.set_state(state_dict['np_rng_state'])
torch.set_rng_state(state_dict['torch_rng_state'])
torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
mpu.get_cuda_rng_tracker().set_states(
state_dict['rng_tracker_states'])
except KeyError:
print_rank_0('Unable to load optimizer from checkpoint {}. '
'Specify --no-load-rng or --finetune to prevent '
'attempting to load the optimizer state, '
'exiting ...'.format(checkpoint_name))
sys.exit()
torch.distributed.barrier()
if mpu.get_data_parallel_rank() == 0:
print(' successfully loaded {}'.format(checkpoint_name))
return iteration
......@@ -22,15 +22,11 @@ import numpy as np
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
#from megatron.global_vars import get_args
#from megatron.global_vars import get_adlr_autoresume
from megatron.global_vars import get_args
from megatron.global_vars import get_adlr_autoresume
from megatron import mpu
from megatron.fp16 import FP16_Module
from megatron.fp16 import FP16_Optimizer
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import get_params_for_weight_decay_optimization
def print_rank_0(message):
......@@ -52,6 +48,41 @@ def reduce_losses(losses):
return reduced_losses
def report_memory(name):
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
string = name + ' memory (MB)'
string += ' | allocated: {}'.format(
torch.cuda.memory_allocated() / mega_bytes)
string += ' | max allocated: {}'.format(
torch.cuda.max_memory_allocated() / mega_bytes)
string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
string += ' | max cached: {}'.format(
torch.cuda.max_memory_cached()/ mega_bytes)
print_rank_0(string)
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
string = 'iteration, rank, index, model-parallel,min, max, norm\n'
optimizer_ = optimizer
if isinstance(optimizer, FP16_Optimizer):
optimizer_ = optimizer.optimizer
for param_group in optimizer_.param_groups:
for param in param_group['params']:
index += 1
min_ = param.data.min()
max_ = param.data.max()
norm = param.data.norm()
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True)
#######################################
def check_adlr_autoresume_termination(iteration, model, optimizer,
lr_scheduler, args):
# Add barrier to ensure consistnecy.
......@@ -66,8 +97,6 @@ def check_adlr_autoresume_termination(iteration, model, optimizer,
exit(0)
def get_ltor_masks_and_position_ids(data,
eod_token,
reset_position_ids,
......@@ -125,42 +154,6 @@ def get_ltor_masks_and_position_ids(data,
return attention_mask, loss_mask, position_ids
def print_params_min_max_norm(optimizer, iteration):
"""Print min, max, and norm of all parameters."""
index = 0
rank = torch.distributed.get_rank()
string = 'iteration, rank, index, model-parallel,min, max, norm\n'
optimizer_ = optimizer
if isinstance(optimizer, FP16_Optimizer):
optimizer_ = optimizer.optimizer
for param_group in optimizer_.param_groups:
for param in param_group['params']:
index += 1
min_ = param.data.min()
max_ = param.data.max()
norm = param.data.norm()
string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
iteration, rank, index, int(param.model_parallel))
string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
print(string, flush=True)
def report_memory(name):
"""Simple GPU memory report."""
mega_bytes = 1024.0 * 1024.0
string = name + ' memory (MB)'
string += ' | allocated: {}'.format(
torch.cuda.memory_allocated() / mega_bytes)
string += ' | max allocated: {}'.format(
torch.cuda.max_memory_allocated() / mega_bytes)
string += ' | cached: {}'.format(torch.cuda.memory_cached() / mega_bytes)
string += ' | max cached: {}'.format(
torch.cuda.max_memory_cached()/ mega_bytes)
print_rank_0(string)
def vocab_size_with_padding(num_tokens, args):
after = num_tokens
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment