Commit 0404891e authored by zhe chen's avatar zhe chen
Browse files

Fix bug in newer slurm system

parent d36b7c67
......@@ -581,7 +581,9 @@ if __name__ == '__main__':
assert has_native_amp, 'Please update pytorch(1.6+) to support amp!'
# init distributed env
if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE']) != 1:
# In newer versions of Slurm, the format of `SLURM_TASKS_PER_NODE` changed from a single
# numeric string (e.g. `8`) to a form like `8(x2)`, meaning 8 tasks per node across 2 nodes.
# NOTE(review): indexing with `[0]` reads only the first character, so a value of 10+
# tasks per node (e.g. `16(x4)`) is misread as `1`; parsing with `.split('(')[0]` would be robust.
if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE'][0]) != 1:
print('\nDist init: SLURM')
rank = int(os.environ['SLURM_PROCID'])
gpu = rank % torch.cuda.device_count()
......
......@@ -497,7 +497,9 @@ if __name__ == '__main__':
args, config = parse_option()
# init distributed env
if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE']) != 1:
# In newer versions of Slurm, the format of `SLURM_TASKS_PER_NODE` changed from a single
# numeric string (e.g. `8`) to a form like `8(x2)`, meaning 8 tasks per node across 2 nodes.
# NOTE(review): indexing with `[0]` reads only the first character, so a value of 10+
# tasks per node (e.g. `16(x4)`) is misread as `1`; parsing with `.split('(')[0]` would be robust.
if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE'][0]) != 1:
print('\nDist init: SLURM')
rank = int(os.environ['SLURM_PROCID'])
gpu = rank % torch.cuda.device_count()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment