Commit 0404891e authored by zhe chen's avatar zhe chen
Browse files

Fix bug in newer slurm system

parent d36b7c67
...@@ -581,7 +581,9 @@ if __name__ == '__main__': ...@@ -581,7 +581,9 @@ if __name__ == '__main__':
assert has_native_amp, 'Please update pytorch(1.6+) to support amp!' assert has_native_amp, 'Please update pytorch(1.6+) to support amp!'
# init distributed env # init distributed env
if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE']) != 1: # In the newer versions of Slurm, the format of `SLURM_TASKS_PER_NODE` has changed from a single
# numeric string to a format like `8(xn)`, which indicates that n nodes are used in the training.
if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE'][0]) != 1:
print('\nDist init: SLURM') print('\nDist init: SLURM')
rank = int(os.environ['SLURM_PROCID']) rank = int(os.environ['SLURM_PROCID'])
gpu = rank % torch.cuda.device_count() gpu = rank % torch.cuda.device_count()
......
...@@ -497,7 +497,9 @@ if __name__ == '__main__': ...@@ -497,7 +497,9 @@ if __name__ == '__main__':
args, config = parse_option() args, config = parse_option()
# init distributed env # init distributed env
if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE']) != 1: # In the newer versions of Slurm, the format of `SLURM_TASKS_PER_NODE` has changed from a single
# numeric string to a format like `8(xn)`, which indicates that n nodes are used in the training.
if 'SLURM_PROCID' in os.environ and int(os.environ['SLURM_TASKS_PER_NODE'][0]) != 1:
print('\nDist init: SLURM') print('\nDist init: SLURM')
rank = int(os.environ['SLURM_PROCID']) rank = int(os.environ['SLURM_PROCID'])
gpu = rank % torch.cuda.device_count() gpu = rank % torch.cuda.device_count()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment