Commit ede0a58f authored by mshoeybi

simplified the iteration read check across ranks

parent a8f4edcb
@@ -124,26 +124,19 @@ def read_metadata(tracker_filename):
     assert iteration > 0 or release, 'error parsing metadata file {}'.format(
         tracker_filename)
-    # Make sure all the ranks read the same meta data.
-    iters_cuda = torch.cuda.LongTensor(
-        torch.distributed.get_world_size()).fill_(0)
-    iters_cuda[torch.distributed.get_rank()] = iteration
-    torch.distributed.all_reduce(iters_cuda)
+    # Get the max iteration retrieved across the ranks.
+    iters_cuda = torch.cuda.LongTensor([iteration])
+    torch.distributed.all_reduce(iters_cuda, op=torch.distributed.ReduceOp.MAX)
+    max_iter = iters_cuda[0].item()
     # We should now have all the same iteration.
     # If not, print a warning and chose the maximum
     # iteration across all ranks.
-    max_iter = iters_cuda.max().item()
-    min_iter = iters_cuda.min().item()
-    if max_iter == min_iter:
-        print_rank_0('> meta data was loaded successfully ...')
-    else:
-        for rank in range(torch.distributed.get_world_size()):
-            if iters_cuda[rank] != max_iters:
-                print_rank_0('WARNING: on rank {} found iteration {} in the '
-                             'meta data while max iteration across the ranks '
-                             'is {}, replacing it with max iteration.'.format(
-                                 rank, iters_cuda[rank], max_iter))
+    if iteration != max_iter:
+        print('WARNING: on rank {} found iteration {} in the '
+              'metadata while max iteration across the ranks '
+              'is {}, replacing it with max iteration.'.format(
+                  rank, iteration, max_iter), flush=True)
     return max_iter, release
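For context on the new path, below is a minimal standalone sketch of the simplified check, not the repository's code: it assumes the gloo backend, CPU tensors in place of torch.cuda.LongTensor, a made-up per-rank iteration, and a hypothetical check_iteration helper and port number.

# Minimal sketch of the simplified iteration check (assumptions noted above):
# each rank contributes its locally parsed iteration, one MAX all-reduce gives
# the agreed-upon value, and a rank holding a stale iteration warns and adopts
# the maximum.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def check_iteration(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # Pretend each rank parsed a (possibly different) iteration from the
    # tracker file; rank 1 is deliberately stale here.
    iteration = 1000 if rank == 0 else 900

    # One MAX all-reduce on a one-element tensor replaces the old
    # gather-into-a-world-size-tensor logic.
    iters = torch.LongTensor([iteration])
    dist.all_reduce(iters, op=dist.ReduceOp.MAX)
    max_iter = iters[0].item()

    if iteration != max_iter:
        print('WARNING: on rank {} found iteration {} while max iteration '
              'across the ranks is {}, replacing it with max iteration.'.format(
                  rank, iteration, max_iter), flush=True)

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(check_iteration, args=(2,), nprocs=2)

The design point is that a single MAX all-reduce both picks the agreed-upon iteration and lets each rank detect whether its own value is stale, replacing the world-size-length tensor, the default SUM all-reduce, and the per-rank scan on the old path (which also referenced an undefined max_iters).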