Unverified Commit 7ce6a2b5 authored by Shaoshuai Shi's avatar Shaoshuai Shi Committed by GitHub
Browse files

Bug fix: resolve hang when training with dist_train.sh; add support for a configurable tcp_port (#784)

parent 274c90c5
...@@ -161,9 +161,11 @@ def init_dist_slurm(tcp_port, local_rank, backend='nccl'): ...@@ -161,9 +161,11 @@ def init_dist_slurm(tcp_port, local_rank, backend='nccl'):
def init_dist_pytorch(tcp_port, local_rank, backend='nccl'): def init_dist_pytorch(tcp_port, local_rank, backend='nccl'):
if mp.get_start_method(allow_none=True) is None: if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn') mp.set_start_method('spawn')
os.environ['MASTER_PORT'] = str(tcp_port)
os.environ['MASTER_ADDR'] = 'localhost'
num_gpus = torch.cuda.device_count() num_gpus = torch.cuda.device_count()
torch.cuda.set_device(local_rank % num_gpus) torch.cuda.set_device(local_rank % num_gpus)
dist.init_process_group( dist.init_process_group(
backend=backend, backend=backend,
# init_method='tcp://127.0.0.1:%d' % tcp_port, # init_method='tcp://127.0.0.1:%d' % tcp_port,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment