Commit 82dac9c9 authored by Roshan Rao's avatar Roshan Rao Committed by mcarilli
Browse files

update _amp_state to check distributed on maybe_print (#620)

parent 37cdaf4a
...@@ -35,12 +35,9 @@ def warn_or_err(msg): ...@@ -35,12 +35,9 @@ def warn_or_err(msg):
# "hard_override=True to amp.initialize.") # "hard_override=True to amp.initialize.")
distributed = False
if 'WORLD_SIZE' in os.environ:
distributed = int(os.environ['WORLD_SIZE']) > 1
def maybe_print(msg, rank0=False): def maybe_print(msg, rank0=False):
distributed = torch.distributed.is_initialized() and \
torch.distributed.get_world_size() > 1
if _amp_state.verbosity > 0: if _amp_state.verbosity > 0:
if rank0: if rank0:
if distributed: if distributed:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment