Commit 0722f878 authored by mshoeybi

added across-rank sync for checkpoint iteration load, fixed tensor type for timing, and fixed consumed validation samples for iteration-based runs
parent 91fcbd68
@@ -106,6 +106,47 @@ def get_checkpoint_tracker_filename(checkpoints_path):
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
+def read_metadata(tracker_filename):
+    # Read the tracker file and either set the iteration or
+    # mark it as a release checkpoint.
+    iteration = 0
+    release = False
+    with open(tracker_filename, 'r') as f:
+        metastring = f.read().strip()
+        try:
+            iteration = int(metastring)
+        except ValueError:
+            release = metastring == 'release'
+            if not release:
+                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
+                    tracker_filename))
+                sys.exit()
+    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
+        tracker_filename)
+
+    # Make sure all the ranks read the same metadata.
+    iters_cuda = torch.cuda.LongTensor(
+        torch.distributed.get_world_size()).fill_(0)
+    iters_cuda[torch.distributed.get_rank()] = iteration
+    torch.distributed.all_reduce(iters_cuda)
+
+    # We should now all have the same iteration.
+    # If not, print a warning and choose the maximum
+    # iteration across all ranks.
+    max_iter = iters_cuda.max().item()
+    min_iter = iters_cuda.min().item()
+    if max_iter == min_iter:
+        print_rank_0('> metadata was loaded successfully ...')
+    else:
+        for rank in range(torch.distributed.get_world_size()):
+            if iters_cuda[rank] != max_iter:
+                print_rank_0('WARNING: on rank {} found iteration {} in the '
+                             'metadata while max iteration across the ranks '
+                             'is {}, replacing it with max iteration.'.format(
+                                 rank, iters_cuda[rank], max_iter))
+
+    return max_iter, release
def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""
    args = get_args()
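The heart of the new read_metadata is the agreement check: each rank writes its locally read iteration into its own slot of a world-size tensor, a single sum all-reduce (each slot is non-zero on exactly one rank) makes every reading visible everywhere, and the maximum wins. Below is a minimal standalone sketch of that pattern, using the gloo backend on CPU in place of the CUDA tensors above; sync_iteration and worker are illustrative names, not part of the commit.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def sync_iteration(rank, world_size, local_iteration):
    # One slot per rank: each rank fills only its own entry, so the
    # default SUM all-reduce leaves every rank's reading in every copy.
    iters = torch.zeros(world_size, dtype=torch.long)
    iters[rank] = local_iteration
    dist.all_reduce(iters)
    max_iter = iters.max().item()
    if iters.min().item() != max_iter and rank == 0:
        print('WARNING: ranks disagree on iteration: {}'.format(iters.tolist()))
    return max_iter

def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    # Simulate rank 1 reading a stale tracker file.
    local_iteration = 900 if rank == 1 else 1000
    print('rank {} resumes at iteration {}'.format(
        rank, sync_iteration(rank, world_size, local_iteration)))
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)

Both ranks print iteration 1000, which is the behavior read_metadata enforces before any checkpoint name is built from the iteration.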
@@ -260,21 +301,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
-    iteration = 0
-    release = False
-    with open(tracker_filename, 'r') as f:
-        metastring = f.read().strip()
-        try:
-            iteration = int(metastring)
-        except ValueError:
-            release = metastring == 'release'
-            if not release:
-                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
-                    tracker_filename))
-                sys.exit()
-    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
-        tracker_filename)
+    iteration, release = read_metadata(tracker_filename)

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
@@ -96,7 +96,7 @@ def pretrain(train_valid_test_dataset_provider,
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
    global _TRAIN_START_TIME
-    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
+    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
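The FloatTensor to DoubleTensor change is a precision fix: _TRAIN_START_TIME holds an absolute Unix timestamp on the order of 1.7e9 seconds, and float32 spacing at that magnitude is 2**(30-23) = 128 seconds, so a single-precision copy can be off by about a minute before any subtraction happens. A quick standalone check, using plain CPU tensors in place of the CUDA ones:

import time
import torch

start = time.time()  # absolute timestamp, ~1.7e9 seconds since the epoch
as_float = torch.tensor([start], dtype=torch.float32)
as_double = torch.tensor([start], dtype=torch.float64)

# float32 has 23 fraction bits, so values near 2**30 round to multiples
# of 128 seconds; float64 keeps sub-microsecond resolution here.
print('float32 error: {:.3f} s'.format(abs(as_float.item() - start)))
print('float64 error: {:.9f} s'.format(abs(as_double.item() - start)))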
@@ -787,10 +787,9 @@ def build_train_valid_test_data_iterators(
            'only backward compatibility support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
-        assert args.train_samples is None, \
-            'only backward compatibility support for iteration-based training'
-        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
-            args.eval_iters * args.global_batch_size
+        if args.train_samples is None:
+            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
+                args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:
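Replacing the assert with an if means sample-based runs (where args.train_samples is set) now simply skip the iteration-based reconstruction instead of crashing on resume. The arithmetic itself is easy to check by hand with made-up values:

# Hypothetical resume values, not taken from the commit.
iteration = 5000
eval_interval = 1000     # validate every 1000 training iterations
eval_iters = 10          # batches consumed per validation run
global_batch_size = 512

# Validation ran iteration // eval_interval = 5 times, each run
# consuming eval_iters * global_batch_size = 5120 samples.
consumed_valid_samples = (iteration // eval_interval) * \
    eval_iters * global_batch_size
print(consumed_valid_samples)  # 25600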