"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c44fba889965638f447d20f5730745c7963494d7"
Unverified Commit 69048ff0 authored by Harry's avatar Harry Committed by GitHub
Browse files

Specifying distributed training port in os.environ when training with slurm (#362)

* feat: support for os.environ port for slurm training

* fix: port data type

* feat: add flawed unittest

* feat: add flawed unittest

* docs: add comments

* fix: unittest

* fix: unittest
parent 8e80223e
...@@ -35,7 +35,17 @@ def _init_dist_mpi(backend, **kwargs):
    raise NotImplementedError


def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
    """
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
...@@ -43,7 +53,14 @@ def _init_dist_slurm(backend, port=None):
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(
        f'scontrol show hostname {node_list} | head -n1')
    # specify master port
    if port is not None:
        os.environ['MASTER_PORT'] = str(port)
    elif 'MASTER_PORT' in os.environ:
        pass  # use MASTER_PORT in the environment variable
    else:
        # 29500 is torch.distributed default port
        os.environ['MASTER_PORT'] = '29500'
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
......
import os
from unittest.mock import patch
import pytest
from mmcv.runner import init_dist
@patch('torch.cuda.device_count', return_value=1)
@patch('torch.cuda.set_device')
@patch('torch.distributed.init_process_group')
@patch('subprocess.getoutput', return_value='127.0.0.1')
def test_init_dist(mock_getoutput, mock_dist_init, mock_set_device,
                   mock_device_count):
    """Test ``init_dist`` with the slurm launcher.

    CUDA device discovery, process-group initialization and the
    ``scontrol`` lookup are all mocked, so the test exercises only the
    environment-variable plumbing of ``_init_dist_slurm``. Three cases
    are covered: default port, explicit ``port`` argument, and reuse of
    an already-set ``MASTER_PORT`` environment variable.
    """
    with pytest.raises(ValueError):
        # launcher must be one of {'pytorch', 'mpi', 'slurm'}
        init_dist('invaliad_launcher')

    # Fake a single-task slurm allocation.
    os.environ['SLURM_PROCID'] = '0'
    os.environ['SLURM_NTASKS'] = '1'
    os.environ['SLURM_NODELIST'] = '[0]'  # haven't check the correct form
    # Drop any stale MASTER_PORT so the default-port branch is actually
    # exercised; a leftover value from the outer environment (or an
    # earlier test) would otherwise make the first assertion flaky.
    os.environ.pop('MASTER_PORT', None)

    def _check_common_env():
        # Expectations shared by every successful init_dist('slurm') call.
        assert os.environ['MASTER_ADDR'] == '127.0.0.1'
        assert os.environ['WORLD_SIZE'] == '1'
        assert os.environ['RANK'] == '0'
        mock_set_device.assert_called_with(0)
        mock_getoutput.assert_called_with(
            'scontrol show hostname [0] | head -n1')
        mock_dist_init.assert_called_with(backend='nccl')

    # no port is specified, use default port 29500
    init_dist('slurm')
    assert os.environ['MASTER_PORT'] == '29500'
    _check_common_env()

    # port is specified with argument 'port'
    init_dist('slurm', port=29505)
    assert os.environ['MASTER_PORT'] == '29505'
    _check_common_env()

    # no port argument: the MASTER_PORT left by the previous call is reused
    init_dist('slurm')
    assert os.environ['MASTER_PORT'] == '29505'
    _check_common_env()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment