Unverified commit cff3fecc, authored by Alex Yang, committed by GitHub
Browse files

[Feature] Support finding free port in _init_dist_slurm() (#1846)



* [feat]: support finding a free port in _init_dist_slurm

* fix format

* Update mmcv/runner/dist_utils.py

Should support a port taken by a non-localhost address.
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Update dist_utils.py

Add Copyright.

* rename inner function

* Update mmcv/runner/dist_utils.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* fix dist_utils.py

change _is_port_in_use() criterion.

* Update dist_utils.py

rename _is_port_in_use to _is_free_port

* Update mmcv/runner/dist_utils.py
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update dist_utils.py

fix lint

* Update dist_utils.py

fix lint
Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent c33f2489
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools import functools
import os import os
import socket
import subprocess import subprocess
from collections import OrderedDict from collections import OrderedDict
...@@ -11,6 +13,24 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors, ...@@ -11,6 +13,24 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors) _unflatten_dense_tensors)
def _find_free_port():
    """Ask the operating system for an unused TCP port.

    A temporary IPv4 TCP socket is bound to port 0, the port number the OS
    assigned is read back, and the socket is closed again.

    Returns:
        int: The port number that the OS assigned to the temporary socket.
    """
    # Copied from https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501
    # The context manager guarantees the socket is closed even when bind()
    # or getsockname() raises, so no file descriptor is leaked.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Binding to port 0 will cause the OS to find an available port for us
        sock.bind(('', 0))
        port = sock.getsockname()[1]
    # NOTE: there is still a chance the port could be taken by other processes.
    return port
def _is_free_port(port):
    """Check whether nothing accepts TCP connections on ``port`` locally.

    The port is probed on every IPv4 address the local hostname resolves to
    and on ``localhost``. It is considered free only if no connection
    attempt succeeds.

    Args:
        port (int): The TCP port number to probe.

    Returns:
        bool: True if every connection attempt failed, False as soon as one
        address accepts the connection.
    """
    try:
        ips = socket.gethostbyname_ex(socket.gethostname())[-1]
    except (socket.gaierror, socket.herror):
        # The local hostname cannot be resolved (common in containers);
        # fall back to probing only localhost instead of crashing.
        ips = []
    ips.append('localhost')
    for ip in ips:
        # Use a fresh socket per address: the state of a socket after a
        # failed connect() is unspecified, and some platforms reject any
        # further connect() on it, which would report a busy port as free.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            if s.connect_ex((ip, port)) == 0:
                return False
    return True
def init_dist(launcher, backend='nccl', **kwargs): def init_dist(launcher, backend='nccl', **kwargs):
if mp.get_start_method(allow_none=True) is None: if mp.get_start_method(allow_none=True) is None:
mp.set_start_method('spawn') mp.set_start_method('spawn')
...@@ -64,8 +84,12 @@ def _init_dist_slurm(backend, port=None): ...@@ -64,8 +84,12 @@ def _init_dist_slurm(backend, port=None):
elif 'MASTER_PORT' in os.environ: elif 'MASTER_PORT' in os.environ:
pass # use MASTER_PORT in the environment variable pass # use MASTER_PORT in the environment variable
else: else:
# 29500 is torch.distributed default port # if torch.distributed default port(29500) is available
# then use it, else find a free port
if _is_free_port(29500):
os.environ['MASTER_PORT'] = '29500' os.environ['MASTER_PORT'] = '29500'
else:
os.environ['MASTER_PORT'] = str(_find_free_port())
# use MASTER_ADDR in the environment variable if it already exists # use MASTER_ADDR in the environment variable if it already exists
if 'MASTER_ADDR' not in os.environ: if 'MASTER_ADDR' not in os.environ:
os.environ['MASTER_ADDR'] = addr os.environ['MASTER_ADDR'] = addr
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment