Unverified Commit 5fe27d59 authored by José Morales's avatar José Morales Committed by GitHub
Browse files

[dask] find all needed ports in each host at once (fixes #4458) (#4498)

* find all needed ports in each host at once

* lint

* better naming

* use _HostWorkers in test
parent 1dbf4382
...@@ -7,7 +7,7 @@ dask.Array and dask.DataFrame collections. ...@@ -7,7 +7,7 @@ dask.Array and dask.DataFrame collections.
It is based on dask-lightgbm, which was based on dask-xgboost. It is based on dask-lightgbm, which was based on dask-xgboost.
""" """
import socket import socket
from collections import defaultdict from collections import defaultdict, namedtuple
from copy import deepcopy from copy import deepcopy
from enum import Enum, auto from enum import Enum, auto
from functools import partial from functools import partial
...@@ -30,6 +30,8 @@ _DaskVectorLike = Union[dask_Array, dask_Series] ...@@ -30,6 +30,8 @@ _DaskVectorLike = Union[dask_Array, dask_Series]
_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix] _DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]] _PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
_HostWorkers = namedtuple('HostWorkers', ['default', 'all'])
class _DatasetNames(Enum): class _DatasetNames(Enum):
"""Placeholder names used by lightgbm.dask internals to say 'also evaluate the training data'. """Placeholder names used by lightgbm.dask internals to say 'also evaluate the training data'.
...@@ -62,18 +64,71 @@ def _get_dask_client(client: Optional[Client]) -> Client: ...@@ -62,18 +64,71 @@ def _get_dask_client(client: Optional[Client]) -> Client:
return client return client
def _find_random_open_port() -> int: def _find_n_open_ports(n: int) -> List[int]:
"""Find a random open port on localhost. """Find n random open ports on localhost.
Returns Returns
------- -------
port : int ports : list of int
A free port on localhost n random open ports on localhost.
""" """
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: sockets = []
for _ in range(n):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('', 0)) s.bind(('', 0))
port = s.getsockname()[1] sockets.append(s)
return port ports = []
for s in sockets:
ports.append(s.getsockname()[1])
s.close()
return ports
def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWorkers]:
    """Group all worker addresses by hostname.

    Returns
    -------
    host_to_workers : dict
        mapping from hostname to all its workers.
    """
    grouped: Dict[str, _HostWorkers] = {}
    for address in worker_addresses:
        hostname = urlparse(address).hostname
        # first address seen for a host becomes its 'default' worker
        entry = grouped.setdefault(hostname, _HostWorkers(default=address, all=[]))
        entry.all.append(address)
    return grouped
def _assign_open_ports_to_workers(
    client: Client,
    host_to_workers: Dict[str, _HostWorkers]
) -> Dict[str, int]:
    """Assign an open port to each worker.

    Returns
    -------
    worker_to_port: dict
        mapping from worker address to an open port.
    """
    # one task per host: find as many ports as that host has workers, pinned
    # to the host's default worker so the ports are valid on that machine
    futures_by_host = {
        hostname: client.submit(
            _find_n_open_ports,
            n=len(workers.all),
            workers=[workers.default],
            pure=False,
            allow_other_workers=False,
        )
        for hostname, workers in host_to_workers.items()
    }
    ports_by_host = client.gather(futures_by_host)
    # pair each worker on a host with one of the ports found on that host
    return {
        worker: port
        for hostname, workers in host_to_workers.items()
        for worker, port in zip(workers.all, ports_by_host[hostname])
    }
def _concat(seq: List[_DaskPart]) -> _DaskPart: def _concat(seq: List[_DaskPart]) -> _DaskPart:
...@@ -330,44 +385,6 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[ ...@@ -330,44 +385,6 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[
return out return out
def _possibly_fix_worker_map_duplicates(worker_map: Dict[str, int], client: Client) -> Dict[str, int]:
    """Fix any duplicate IP-port pairs in a ``worker_map``."""
    worker_map = deepcopy(worker_map)

    # record the ports already claimed on each host, and remember which
    # workers collided with an earlier worker on the same host
    host_to_port = defaultdict(set)
    workers_that_need_new_ports = []
    for worker, port in worker_map.items():
        host = urlparse(worker).hostname
        if port in host_to_port[host]:
            workers_that_need_new_ports.append(worker)
        else:
            host_to_port[host].add(port)

    # if any duplicates were found, search for new ports one by one
    for worker in workers_that_need_new_ports:
        _log_info(f"Searching for a LightGBM training port for worker '{worker}'")
        host = urlparse(worker).hostname
        for _ in range(100):
            candidate_port = client.submit(
                _find_random_open_port,
                workers=[worker],
                allow_other_workers=False,
                pure=False
            ).result()
            if candidate_port not in host_to_port[host]:
                worker_map[worker] = candidate_port
                host_to_port[host].add(candidate_port)
                break
        else:
            # loop exhausted without finding an unused port
            raise LightGBMError(
                "Failed to find an open port. Try re-running training or explicitly setting 'machines' or 'local_listen_port'."
            )

    return worker_map
def _train( def _train(
client: Client, client: Client,
data: _DaskMatrixLike, data: _DaskMatrixLike,
...@@ -726,18 +743,8 @@ def _train( ...@@ -726,18 +743,8 @@ def _train(
} }
else: else:
_log_info("Finding random open ports for workers") _log_info("Finding random open ports for workers")
# this approach with client.run() is faster than searching for ports host_to_workers = _group_workers_by_host(worker_map.keys())
# serially, but can produce duplicates sometimes. Try the fast approach one worker_address_to_port = _assign_open_ports_to_workers(client, host_to_workers)
# time, then pass it through a function that will use a slower but more reliable
# approach if duplicates are found.
worker_address_to_port = client.run(
_find_random_open_port,
workers=list(worker_addresses)
)
worker_address_to_port = _possibly_fix_worker_map_duplicates(
worker_map=worker_address_to_port,
client=client
)
machines = ','.join([ machines = ','.join([
f'{urlparse(worker_address).hostname}:{port}' f'{urlparse(worker_address).hostname}:{port}'
......
...@@ -446,11 +446,29 @@ def test_classifier_pred_contrib(output, task, cluster): ...@@ -446,11 +446,29 @@ def test_classifier_pred_contrib(output, task, cluster):
assert len(np.unique(preds_with_contrib[:, base_value_col]) == 1) assert len(np.unique(preds_with_contrib[:, base_value_col]) == 1)
def test_group_workers_by_host():
    hosts = [f'0.0.0.{i}' for i in range(2)]
    workers = [f'tcp://{host}:{p}' for p in range(2) for host in hosts]
    expected = {}
    for host in hosts:
        addresses = [f'tcp://{host}:0', f'tcp://{host}:1']
        # the first address seen for a host should become its default worker
        expected[host] = lgb.dask._HostWorkers(default=addresses[0], all=addresses)
    assert lgb.dask._group_workers_by_host(workers) == expected
def test_assign_open_ports_to_workers(cluster):
with Client(cluster) as client: with Client(cluster) as client:
for _ in range(5): workers = client.scheduler_info()['workers'].keys()
worker_address_to_port = client.run(lgb.dask._find_random_open_port) n_workers = len(workers)
host_to_workers = lgb.dask._group_workers_by_host(workers)
for _ in range(1_000):
worker_address_to_port = lgb.dask._assign_open_ports_to_workers(client, host_to_workers)
found_ports = worker_address_to_port.values() found_ports = worker_address_to_port.values()
assert len(found_ports) == n_workers
# check that found ports are different for same address (LocalCluster) # check that found ports are different for same address (LocalCluster)
assert len(set(found_ports)) == len(found_ports) assert len(set(found_ports)) == len(found_ports)
# check that the ports are indeed open # check that the ports are indeed open
...@@ -459,37 +477,6 @@ def test_find_random_open_port(cluster): ...@@ -459,37 +477,6 @@ def test_find_random_open_port(cluster):
s.bind(('', port)) s.bind(('', port))
def test_possibly_fix_worker_map(capsys, cluster):
    with Client(cluster) as client:
        worker_addresses = list(client.scheduler_info()["workers"].keys())
        retry_msg = 'Searching for a LightGBM training port for worker'

        # should handle worker maps without any duplicates
        map_without_duplicates = {
            address: 12400 + offset
            for offset, address in enumerate(worker_addresses)
        }
        patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
            client=client,
            worker_map=map_without_duplicates
        )
        assert patched_map == map_without_duplicates
        assert retry_msg not in capsys.readouterr().out

        # should handle worker maps with duplicates
        map_with_duplicates = dict.fromkeys(worker_addresses, 12400)
        patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
            client=client,
            worker_map=map_with_duplicates
        )
        assert retry_msg in capsys.readouterr().out
        assert len(set(patched_map.values())) == len(worker_addresses)
def test_training_does_not_fail_on_port_conflicts(cluster): def test_training_does_not_fail_on_port_conflicts(cluster):
with Client(cluster) as client: with Client(cluster) as client:
_, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array') _, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
...@@ -1406,7 +1393,7 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c ...@@ -1406,7 +1393,7 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c
# model 2 - machines given # model 2 - machines given
n_workers = len(client.scheduler_info()['workers']) n_workers = len(client.scheduler_info()['workers'])
open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)] open_ports = lgb.dask._find_n_open_ports(n_workers)
dask_model2 = dask_model_factory( dask_model2 = dask_model_factory(
n_estimators=5, n_estimators=5,
num_leaves=5, num_leaves=5,
...@@ -1452,7 +1439,7 @@ def test_machines_should_be_used_if_provided(task, cluster): ...@@ -1452,7 +1439,7 @@ def test_machines_should_be_used_if_provided(task, cluster):
n_workers = len(client.scheduler_info()['workers']) n_workers = len(client.scheduler_info()['workers'])
assert n_workers > 1 assert n_workers > 1
open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)] open_ports = lgb.dask._find_n_open_ports(n_workers)
dask_model = dask_model_factory( dask_model = dask_model_factory(
n_estimators=5, n_estimators=5,
num_leaves=5, num_leaves=5,
...@@ -1474,7 +1461,7 @@ def test_machines_should_be_used_if_provided(task, cluster): ...@@ -1474,7 +1461,7 @@ def test_machines_should_be_used_if_provided(task, cluster):
client.restart() client.restart()
# an informative error should be raised if "machines" has duplicates # an informative error should be raised if "machines" has duplicates
one_open_port = lgb.dask._find_random_open_port() one_open_port = lgb.dask._find_n_open_ports(1)
dask_model.set_params( dask_model.set_params(
machines=",".join([ machines=",".join([
f"127.0.0.1:{one_open_port}" f"127.0.0.1:{one_open_port}"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment