Unverified Commit 0e576575 authored by jmoralez's avatar jmoralez Committed by GitHub
Browse files

[dask] use random ports in network setup (#3823)

* use socket.bind with port 0 and client.run to find random open ports

* include test for found ports

* find random open ports as default

* parametrize local_listen_port. type hint to _find_random_open_port. find open ports only on workers with data.

* make indentation consistent and pass list of workers to client.run

* remove socket import

* change random port implementation

* fix test
parent 7777852a
...@@ -45,83 +45,18 @@ def _get_dask_client(client: Optional[Client]) -> Client: ...@@ -45,83 +45,18 @@ def _get_dask_client(client: Optional[Client]) -> Client:
return client return client
def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int: def _find_random_open_port() -> int:
"""Find an open port. """Find a random open port on localhost.
This function tries to find a free port on the machine it's run on. It is intended to
be run once on each Dask worker, sequentially.
Parameters
----------
worker_ip : str
IP address for the Dask worker.
local_listen_port : int
First port to try when searching for open ports.
ports_to_skip: Iterable[int]
An iterable of integers referring to ports that should be skipped. Since multiple Dask
workers can run on the same physical machine, this method may be called multiple times
on the same machine. ``ports_to_skip`` is used to ensure that LightGBM doesn't try to use
the same port for two worker processes running on the same machine.
Returns Returns
------- -------
port : int port : int
A free port on the machine referenced by ``worker_ip``. A free port on localhost
"""
max_tries = 1000
found_port = False
for i in range(max_tries):
out_port = local_listen_port + i
if out_port in ports_to_skip:
continue
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((worker_ip, out_port))
found_port = True
break
# if unavailable, you'll get OSError: Address already in use
except OSError:
continue
if not found_port:
msg = "LightGBM tried %s:%d-%d and could not create a connection. Try setting local_listen_port to a different value."
raise RuntimeError(msg % (worker_ip, local_listen_port, out_port))
return out_port
def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], local_listen_port: int) -> Dict[str, int]:
"""Find an open port on each worker.
LightGBM distributed training uses TCP sockets by default, and this method is used to
identify open ports on each worker so LightGBM can reliably create those sockets.
Parameters
----------
client : dask.distributed.Client
Dask client.
worker_addresses : Iterable[str]
An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port``.
local_listen_port : int
First port to try when searching for open ports.
Returns
-------
result : Dict[str, int]
Dictionary where keys are worker addresses and values are an open port for LightGBM to use.
""" """
lightgbm_ports: Set[int] = set() with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
worker_ip_to_port = {} s.bind(('', 0))
for worker_address in worker_addresses: port = s.getsockname()[1]
port = client.submit( return port
func=_find_open_port,
workers=[worker_address],
worker_ip=urlparse(worker_address).hostname,
local_listen_port=local_listen_port,
ports_to_skip=lightgbm_ports
).result()
lightgbm_ports.add(port)
worker_ip_to_port[worker_address] = port
return worker_ip_to_port
def _concat(seq: List[_DaskPart]) -> _DaskPart: def _concat(seq: List[_DaskPart]) -> _DaskPart:
...@@ -415,10 +350,9 @@ def _train( ...@@ -415,10 +350,9 @@ def _train(
} }
else: else:
_log_info("Finding random open ports for workers") _log_info("Finding random open ports for workers")
worker_address_to_port = _find_ports_for_workers( worker_address_to_port = client.run(
client=client, _find_random_open_port,
worker_addresses=worker_addresses, workers=list(worker_addresses)
local_listen_port=local_listen_port
) )
machines = ','.join([ machines = ','.join([
'%s:%d' % (urlparse(worker_address).hostname, port) '%s:%d' % (urlparse(worker_address).hostname, port)
......
...@@ -174,14 +174,6 @@ def _accuracy_score(dy_true, dy_pred): ...@@ -174,14 +174,6 @@ def _accuracy_score(dy_true, dy_pred):
return da.average(dy_true == dy_pred).compute() return da.average(dy_true == dy_pred).compute()
def _find_random_open_port() -> int:
"""Find a random open port on localhost"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('', 0))
port = s.getsockname()[1]
return port
def _pickle(obj, filepath, serializer): def _pickle(obj, filepath, serializer):
if serializer == 'pickle': if serializer == 'pickle':
with open(filepath, 'wb') as f: with open(filepath, 'wb') as f:
...@@ -343,6 +335,19 @@ def test_classifier_pred_contrib(output, centers, client): ...@@ -343,6 +335,19 @@ def test_classifier_pred_contrib(output, centers, client):
client.close(timeout=CLIENT_CLOSE_TIMEOUT) client.close(timeout=CLIENT_CLOSE_TIMEOUT)
def test_find_random_open_port(client):
    """Workers should each report a distinct port that is actually free."""
    for _ in range(5):
        ports_by_worker = client.run(lgb.dask._find_random_open_port)
        ports = list(ports_by_worker.values())
        # ports must not collide even when workers share one address (LocalCluster)
        assert len(ports) == len(set(ports))
        # every reported port must still be bindable
        for port in ports:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe:
                probe.bind(('', port))
    client.close(timeout=CLIENT_CLOSE_TIMEOUT)
def test_training_does_not_fail_on_port_conflicts(client): def test_training_does_not_fail_on_port_conflicts(client):
_, _, _, dX, dy, dw = _create_data('classification', output='array') _, _, _, dX, dy, dw = _create_data('classification', output='array')
...@@ -885,29 +890,6 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici ...@@ -885,29 +890,6 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
assert_eq(preds_orig_local, preds_loaded_model_local) assert_eq(preds_orig_local, preds_loaded_model_local)
def test_find_open_port_works(listen_port):
    """_find_open_port should skip ports that are already bound."""
    host = '127.0.0.1'
    # occupy listen_port; the helper must settle on a later port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as blocker:
        blocker.bind((host, listen_port))
        new_port = lgb.dask._find_open_port(
            worker_ip=host,
            local_listen_port=listen_port,
            ports_to_skip=set()
        )
        assert listen_port < new_port < listen_port + 1000
    # occupy two consecutive ports; the helper must skip past both
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as blocker_a:
        blocker_a.bind((host, listen_port))
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as blocker_b:
            blocker_b.bind((host, listen_port + 1))
            new_port = lgb.dask._find_open_port(
                worker_ip=host,
                local_listen_port=listen_port,
                ports_to_skip=set()
            )
            assert listen_port + 1 < new_port < listen_port + 1000
def test_warns_and_continues_on_unrecognized_tree_learner(client): def test_warns_and_continues_on_unrecognized_tree_learner(client):
X = da.random.random((1e3, 10)) X = da.random.random((1e3, 10))
y = da.random.random((1e3, 1)) y = da.random.random((1e3, 1))
...@@ -1075,7 +1057,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output ...@@ -1075,7 +1057,7 @@ def test_network_params_not_required_but_respected_if_given(client, task, output
# model 2 - machines given # model 2 - machines given
n_workers = len(client.scheduler_info()['workers']) n_workers = len(client.scheduler_info()['workers'])
open_ports = [_find_random_open_port() for _ in range(n_workers)] open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
dask_model2 = dask_model_factory( dask_model2 = dask_model_factory(
n_estimators=5, n_estimators=5,
num_leaves=5, num_leaves=5,
...@@ -1143,7 +1125,7 @@ def test_machines_should_be_used_if_provided(task, output): ...@@ -1143,7 +1125,7 @@ def test_machines_should_be_used_if_provided(task, output):
client.rebalance() client.rebalance()
n_workers = len(client.scheduler_info()['workers']) n_workers = len(client.scheduler_info()['workers'])
open_ports = [_find_random_open_port() for _ in range(n_workers)] open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
dask_model = dask_model_factory( dask_model = dask_model_factory(
n_estimators=5, n_estimators=5,
num_leaves=5, num_leaves=5,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment