Unverified Commit ac57d5a4 authored by José Morales's avatar José Morales Committed by GitHub
Browse files

[dask] hold ports until training (#5890)

parent 07e3cf47
......@@ -144,7 +144,7 @@ try:
from dask.bag import from_delayed as dask_bag_from_delayed
from dask.dataframe import DataFrame as dask_DataFrame
from dask.dataframe import Series as dask_Series
from dask.distributed import Client, default_client, wait
from dask.distributed import Client, Future, default_client, wait
DASK_INSTALLED = True
except ImportError:
DASK_INSTALLED = False
......@@ -161,6 +161,12 @@ except ImportError:
def __init__(self, *args, **kwargs):
pass
class Future:  # type: ignore
    """Dummy class for dask.distributed.Future."""

    # Placeholder so that code referencing Future still imports cleanly
    # when dask is not installed (DASK_INSTALLED is False); accepts and
    # ignores any constructor arguments.
    def __init__(self, *args, **kwargs):
        pass
class dask_Array: # type: ignore
"""Dummy class for dask.array.Array."""
......
......@@ -6,6 +6,7 @@ dask.Array and dask.DataFrame collections.
It is based on dask-lightgbm, which was based on dask-xgboost.
"""
import operator
import socket
from collections import defaultdict
from copy import deepcopy
......@@ -18,7 +19,7 @@ import numpy as np
import scipy.sparse as ss
from .basic import LightGBMError, _choose_param_value, _ConfigAliases, _log_info, _log_warning
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, Future, LGBMNotFittedError, concat,
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomObjectiveFunction,
......@@ -38,18 +39,21 @@ _DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
class _HostWorkers:
class _RemoteSocket:
def acquire(self) -> int:
    """Bind a TCP socket to a random free port and return that port.

    The socket is stored on ``self.socket`` and kept open, so the port
    stays reserved until ``release`` is called.
    """
    self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # SO_REUSEADDR allows the port to be re-bound immediately after release().
    self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    # Port 0 asks the OS to pick any free port.
    self.socket.bind(('', 0))
    return self.socket.getsockname()[1]
def __init__(self, default: str, all_workers: List[str]):
    # default: address of the worker chosen to act on behalf of this host.
    # all_workers: addresses of every worker running on this host.
    self.default = default
    self.all_workers = all_workers
def release(self) -> None:
    """Close the held socket, freeing the reserved port."""
    self.socket.close()
def __eq__(self, other: object) -> bool:
    # Two instances are equal when they describe the same default worker
    # and the same (order-sensitive) list of workers.
    return (
        isinstance(other, type(self))
        and self.default == other.default
        and self.all_workers == other.all_workers
    )
def _acquire_port() -> Tuple[_RemoteSocket, int]:
    """Reserve a random free port on this machine.

    Returns the ``_RemoteSocket`` holding the port open together with the
    port number it is bound to.
    """
    remote_socket = _RemoteSocket()
    return remote_socket, remote_socket.acquire()
class _DatasetNames(Enum):
......@@ -83,73 +87,40 @@ def _get_dask_client(client: Optional[Client]) -> Client:
return client
def _find_n_open_ports(n: int) -> List[int]:
"""Find n random open ports on localhost.
Returns
-------
ports : list of int
n random open ports on localhost.
"""
sockets = []
for _ in range(n):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('', 0))
sockets.append(s)
ports = []
for s in sockets:
ports.append(s.getsockname()[1])
s.close()
return ports
def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWorkers]:
    """Group all worker addresses by hostname.

    Returns
    -------
    host_to_workers : dict
        mapping from hostname to all its workers.
    """
    host_to_workers: Dict[str, _HostWorkers] = {}
    for address in worker_addresses:
        hostname = urlparse(address).hostname
        if not hostname:
            raise ValueError(f"Could not parse host name from worker address '{address}'")
        existing = host_to_workers.get(hostname)
        if existing is None:
            # first worker seen on this host becomes the default
            host_to_workers[hostname] = _HostWorkers(default=address, all_workers=[address])
        else:
            existing.all_workers.append(address)
    return host_to_workers
def _assign_open_ports_to_workers(
    client: Client,
    workers: List[str],
) -> Tuple[Dict[str, Future], Dict[str, int]]:
    """Assign an open port to each worker.

    Returns
    -------
    worker_to_socket_future: dict
        mapping from worker address to a future pointing to the remote socket.
    worker_to_port: dict
        mapping from worker address to an open port in the worker's host.
    """
    # Acquire a port inside each worker; the remote socket keeps it reserved.
    worker_to_future = {
        worker: client.submit(
            _acquire_port,
            workers=[worker],
            allow_other_workers=False,
            pure=False,
        )
        for worker in workers
    }
    # Schedule futures to retrieve each element of the (socket, port) tuple.
    worker_to_socket_future = {
        worker: client.submit(operator.itemgetter(0), tuple_future)
        for worker, tuple_future in worker_to_future.items()
    }
    worker_to_port_future = {
        worker: client.submit(operator.itemgetter(1), tuple_future)
        for worker, tuple_future in worker_to_future.items()
    }
    # Retrieve the ports now; the socket futures are returned so the caller
    # can release the reserved ports once training starts.
    worker_to_port = client.gather(worker_to_port_future)
    return worker_to_socket_future, worker_to_port
def _concat(seq: List[_DaskPart]) -> _DaskPart:
......@@ -190,6 +161,7 @@ def _train_part(
num_machines: int,
return_model: bool,
time_out: int,
remote_socket: _RemoteSocket,
**kwargs: Any
) -> Optional[LGBMModel]:
network_params = {
......@@ -320,6 +292,8 @@ def _train_part(
kwargs['eval_class_weight'] = [eval_class_weight[i] for i in eval_component_idx]
model = model_factory(**params)
if remote_socket is not None:
remote_socket.release()
try:
if is_ranker:
model.fit(
......@@ -777,6 +751,7 @@ def _train(
machines = params.pop("machines")
# figure out network params
worker_to_socket_future: Dict[str, Future] = {}
worker_addresses = worker_map.keys()
if machines is not None:
_log_info("Using passed-in 'machines' parameter")
......@@ -802,8 +777,7 @@ def _train(
}
else:
_log_info("Finding random open ports for workers")
host_to_workers = _group_workers_by_host(worker_map.keys())
worker_address_to_port = _assign_open_ports_to_workers(client, host_to_workers)
worker_to_socket_future, worker_address_to_port = _assign_open_ports_to_workers(client, list(worker_map.keys()))
machines = ','.join([
f'{urlparse(worker_address).hostname}:{port}'
......@@ -831,6 +805,7 @@ def _train(
local_listen_port=worker_address_to_port[worker],
num_machines=num_machines,
time_out=params.get('time_out', 120),
remote_socket=worker_to_socket_future.get(worker, None),
return_model=(worker == master_worker),
workers=[worker],
allow_other_workers=False,
......
......@@ -519,26 +519,6 @@ def test_classifier_custom_objective(output, task, cluster):
assert_eq(p1_proba, p1_proba_local)
def test_group_workers_by_host():
    """Every hostname should map to all of its workers, first one as default."""
    hosts = [f'0.0.0.{i}' for i in range(2)]
    workers = [f'tcp://{host}:{p}' for p in range(2) for host in hosts]

    host_to_workers = lgb.dask._group_workers_by_host(workers)

    expected = {}
    for host in hosts:
        expected[host] = lgb.dask._HostWorkers(
            default=f'tcp://{host}:0',
            all_workers=[f'tcp://{host}:0', f'tcp://{host}:1'],
        )
    assert host_to_workers == expected
def test_group_workers_by_host_unparseable_host_names():
    """Addresses without a protocol prefix cannot be parsed and should raise."""
    workers_without_protocol = ['0.0.0.1:80', '0.0.0.2:80']
    expected_msg = "Could not parse host name from worker address '0.0.0.1:80'"
    with pytest.raises(ValueError, match=expected_msg):
        lgb.dask._group_workers_by_host(workers_without_protocol)
def test_machines_to_worker_map_unparseable_host_names():
workers = {'0.0.0.1:80': {}, '0.0.0.2:80': {}}
machines = "0.0.0.1:80,0.0.0.2:80"
......@@ -546,23 +526,6 @@ def test_machines_to_worker_map_unparseable_host_names():
lgb.dask._machines_to_worker_map(machines=machines, worker_addresses=workers.keys())
def test_assign_open_ports_to_workers(cluster):
    """Ports found for workers must be open and distinct per worker."""
    with Client(cluster) as client:
        workers = client.scheduler_info()['workers'].keys()
        n_workers = len(workers)
        host_to_workers = lgb.dask._group_workers_by_host(workers)
        # repeat many times to make collisions likely if ports were reused incorrectly
        for _ in range(25):
            worker_address_to_port = lgb.dask._assign_open_ports_to_workers(client, host_to_workers)
            found_ports = worker_address_to_port.values()
            assert len(found_ports) == n_workers
            # check that found ports are different for same address (LocalCluster)
            assert len(set(found_ports)) == len(found_ports)
            # check that the ports are indeed open
            for port in found_ports:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.bind(('', port))
def test_training_does_not_fail_on_port_conflicts(cluster):
with Client(cluster) as client:
_, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
......@@ -1588,15 +1551,17 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c
assert 'machines' not in params
# model 2 - machines given
workers = list(client.scheduler_info()['workers'])
workers_hostname = _get_workers_hostname(cluster)
n_workers = len(client.scheduler_info()['workers'])
open_ports = lgb.dask._find_n_open_ports(n_workers)
remote_sockets, open_ports = lgb.dask._assign_open_ports_to_workers(client, workers)
for s in remote_sockets.values():
s.release()
dask_model2 = dask_model_factory(
n_estimators=5,
num_leaves=5,
machines=",".join([
f"{workers_hostname}:{port}"
for port in open_ports
for port in open_ports.values()
]),
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment