Commit ede0a58f authored by mshoeybi

simplified the iteration read check across ranks

parent a8f4edcb
@@ -124,26 +124,19 @@ def read_metadata(tracker_filename):
     assert iteration > 0 or release, 'error parsing metadata file {}'.format(
         tracker_filename)
-    # Make sure all the ranks read the same meta data.
-    iters_cuda = torch.cuda.LongTensor(
-        torch.distributed.get_world_size()).fill_(0)
-    iters_cuda[torch.distributed.get_rank()] = iteration
-    torch.distributed.all_reduce(iters_cuda)
+    # Get the max iteration retrieved across the ranks.
+    iters_cuda = torch.cuda.LongTensor([iteration])
+    torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
+    max_iter = iters_cuda[0].item()
     # We should now have all the same iteration.
     # If not, print a warning and chose the maximum
     # iteration across all ranks.
-    max_iter = iters_cuda.max().item()
-    min_iter = iters_cuda.min().item()
-    if max_iter == min_iter:
-        print_rank_0('> meta data was loaded successfully ...')
-    else:
-        for rank in range(torch.distributed.get_world_size()):
-            if iters_cuda[rank] != max_iters:
-                print_rank_0('WARNING: on rank {} found iteration {} in the '
-                             'meta data while max iteration across the ranks '
-                             'is {}, replacing it with max iteration.'.format(
-                                 rank, iters_cuda[rank], max_iter))
+    if iteration != max_iter:
+        print('WARNING: on rank {} found iteration {} in the '
+              'metadata while max iteration across the ranks '
+              'is {}, replacing it with max iteration.'.format(
+                  rank, iteration, max_iter), flush=True)
     return max_iter, release
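
A minimal sketch of the simplified check in isolation (not part of the commit): the helper name check_iteration_consensus is illustrative, torch.distributed is assumed to be initialized with a CUDA-capable backend such as NCCL, and the rank is obtained explicitly via torch.distributed.get_rank(), which the added hunk above relies on being available as `rank`.

import torch
import torch.distributed as dist


def check_iteration_consensus(iteration):
    """Agree on a single iteration number across ranks (illustrative sketch).

    Each rank contributes the iteration it read locally; one all_reduce
    with ReduceOp.MAX replaces the old world_size-sized gather plus
    min/max comparison. Assumes dist.init_process_group() has been
    called and the current CUDA device is set for this rank.
    """
    iters_cuda = torch.cuda.LongTensor([iteration])
    dist.all_reduce(iters_cuda, op=dist.ReduceOp.MAX)
    max_iter = iters_cuda[0].item()
    # If this rank read an older iteration, warn and fall back to the max.
    if iteration != max_iter:
        print('WARNING: on rank {} found iteration {} in the metadata '
              'while max iteration across the ranks is {}, replacing it '
              'with max iteration.'.format(
                  dist.get_rank(), iteration, max_iter), flush=True)
    return max_iter

Compared with the removed code, every rank issues the same single-element all_reduce, so no rank-indexed buffer of size world_size is needed and the consistency check reduces to comparing the local value against the reduced maximum.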