"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "537d37c26c619a9aea48bcef44c1d8f45d6d7b1a"
Commit 6ee3d210 authored by Jared Casper

Merge branch 'leftover_fixes' into 'main'

Added across-ranks sync for checkpoint iteration load and a couple of other fixes

See merge request ADLR/megatron-lm!304
parents 91fcbd68 0722f878
@@ -106,6 +106,47 @@ def get_checkpoint_tracker_filename(checkpoints_path):
     return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')
+
+def read_metadata(tracker_filename):
+    # Read the tracker file and either set the iteration or
+    # mark it as a release checkpoint.
+    iteration = 0
+    release = False
+    with open(tracker_filename, 'r') as f:
+        metastring = f.read().strip()
+        try:
+            iteration = int(metastring)
+        except ValueError:
+            release = metastring == 'release'
+            if not release:
+                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
+                    tracker_filename))
+                sys.exit()
+    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
+        tracker_filename)
+
+    # Make sure all the ranks read the same meta data.
+    iters_cuda = torch.cuda.LongTensor(
+        torch.distributed.get_world_size()).fill_(0)
+    iters_cuda[torch.distributed.get_rank()] = iteration
+    torch.distributed.all_reduce(iters_cuda)
+
+    # We should now have all the same iteration.
+    # If not, print a warning and choose the maximum
+    # iteration across all ranks.
+    max_iter = iters_cuda.max().item()
+    min_iter = iters_cuda.min().item()
+    if max_iter == min_iter:
+        print_rank_0('> meta data was loaded successfully ...')
+    else:
+        for rank in range(torch.distributed.get_world_size()):
+            if iters_cuda[rank] != max_iter:
+                print_rank_0('WARNING: on rank {} found iteration {} in the '
+                             'meta data while max iteration across the ranks '
+                             'is {}, replacing it with max iteration.'.format(
+                                 rank, iters_cuda[rank], max_iter))
+
+    return max_iter, release
 def save_checkpoint(iteration, model, optimizer, lr_scheduler):
     """Save a model checkpoint."""
     args = get_args()
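For context: the new read_metadata makes every rank agree on the checkpoint iteration even if their tracker files disagree (for example, slightly stale node-local storage). Below is a minimal, self-contained sketch of the same sync pattern; it uses the gloo backend and CPU tensors so it runs on one machine without GPUs, and the names sync_iteration and _worker are illustrative, not part of the commit.

```python
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def sync_iteration(rank, world_size, local_iteration):
    # Each rank fills only its own slot, then a SUM all-reduce gathers
    # every rank's value; the max wins if the ranks disagree.
    dist.init_process_group('gloo', rank=rank, world_size=world_size)
    iters = torch.zeros(world_size, dtype=torch.long)
    iters[rank] = local_iteration
    dist.all_reduce(iters)  # default op is SUM
    max_iter = iters.max().item()
    if iters[rank].item() != max_iter:
        print(f'rank {rank}: read {local_iteration}, using {max_iter}')
    dist.destroy_process_group()
    return max_iter


def _worker(rank, world_size):
    # Simulate rank 1 reading a stale tracker file.
    local_iteration = 1000 if rank != 1 else 500
    sync_iteration(rank, world_size, local_iteration)


if __name__ == '__main__':
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    mp.spawn(_worker, args=(2,), nprocs=2)
```

Filling a per-rank slot and summing is equivalent to an all_gather of one integer per rank; keeping the per-rank values is what lets the warning name the ranks that disagreed.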
@@ -260,21 +301,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
     # Otherwise, read the tracker file and either set the iteration or
     # mark it as a release checkpoint.
-    iteration = 0
-    release = False
-    with open(tracker_filename, 'r') as f:
-        metastring = f.read().strip()
-        try:
-            iteration = int(metastring)
-        except ValueError:
-            release = metastring == 'release'
-            if not release:
-                print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
-                    tracker_filename))
-                sys.exit()
-    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
-        tracker_filename)
+    iteration, release = read_metadata(tracker_filename)
     # Checkpoint.
     checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
...
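For reference, the tracker file read here is latest_checkpointed_iteration.txt (see get_checkpoint_tracker_filename above): a single line holding either an integer iteration or the literal string release. Here is a small single-process sketch of that contract; parse_tracker is an illustrative stand-in for the parsing half of read_metadata, without the cross-rank sync.

```python
import os
import sys
import tempfile


def parse_tracker(tracker_filename):
    # Returns (iteration, release): an integer iteration, or
    # iteration 0 with release=True for a 'release' checkpoint.
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
    try:
        return int(metastring), False
    except ValueError:
        if metastring == 'release':
            return 0, True
        sys.exit(f'ERROR: Invalid metadata file {tracker_filename}')


if __name__ == '__main__':
    with tempfile.TemporaryDirectory() as d:
        tracker = os.path.join(d, 'latest_checkpointed_iteration.txt')
        with open(tracker, 'w') as f:
            f.write('2000')
        print(parse_tracker(tracker))   # (2000, False)
        with open(tracker, 'w') as f:
            f.write('release')
        print(parse_tracker(tracker))   # (0, True)
```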
@@ -96,7 +96,7 @@ def pretrain(train_valid_test_dataset_provider,
     # This will be closer to what scheduler will see (outside of
     # image ... launches.
     global _TRAIN_START_TIME
-    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
+    start_time_tensor = torch.cuda.DoubleTensor([_TRAIN_START_TIME])
     torch.distributed.all_reduce(start_time_tensor,
                                  op=torch.distributed.ReduceOp.MIN)
     _TRAIN_START_TIME = start_time_tensor.item()
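The FloatTensor to DoubleTensor switch matters because _TRAIN_START_TIME is a Unix-epoch timestamp (roughly 1.7e9 seconds), which exceeds the precision of float32's 24-bit significand: the value gets rounded to a multiple of 128 seconds before the MIN all-reduce, so the measured startup time could be off by up to about a minute. A quick CPU-only check of the precision loss, assuming only stock PyTorch:

```python
import time
import torch

now = time.time()  # roughly 1.7e9 seconds since the epoch
as_float32 = torch.tensor([now], dtype=torch.float32).item()
as_float64 = torch.tensor([now], dtype=torch.float64).item()

print(f'float32 error: {abs(as_float32 - now):.3f} s')  # typically tens of seconds
print(f'float64 error: {abs(as_float64 - now):.9f} s')  # effectively zero
```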
@@ -787,10 +787,9 @@ def build_train_valid_test_data_iterators(
             'only backward compatiblity support for iteration-based training'
         args.consumed_train_samples = args.iteration * args.global_batch_size
     if args.iteration > 0 and args.consumed_valid_samples == 0:
-        assert args.train_samples is None, \
-            'only backward compatiblity support for iteration-based training'
-        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
-            args.eval_iters * args.global_batch_size
+        if args.train_samples is None:
+            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
+                args.eval_iters * args.global_batch_size
 
     # Data loader only on rank 0 of each model parallel group.
     if mpu.get_tensor_model_parallel_rank() == 0:
...
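The hunk above relaxes the old assert into an if: consumed_valid_samples is reconstructed only for iteration-based runs (args.train_samples is None), instead of aborting when a sample-based run resumes from an iteration-based checkpoint. A worked example of the formula, with made-up values:

```python
# Illustrative values, not taken from the commit.
iteration = 5000
eval_interval = 1000      # evaluate every 1000 training iterations
eval_iters = 10           # 10 validation iterations per evaluation
global_batch_size = 512   # samples consumed per iteration

# Evaluations completed so far, times samples consumed per evaluation.
consumed_valid_samples = (iteration // eval_interval) * eval_iters * global_batch_size
print(consumed_valid_samples)  # 5 * 10 * 512 = 25600
```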