Unverified Commit f6d2dce4 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[dask] [python-package] Search for available ports when setting up network (fixes #3753) (#3766)



* starting work

* fixed port-binding issue on localhost

* minor cleanup

* updates

* getting closer

* definitely working for LocalCluster

* it works, it works

* docs

* add tests

* removing testing-only files

* linting

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* remove duplicated code

* remove unnecessary listen()
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 9bacf03c
...@@ -5,7 +5,9 @@ This module enables you to perform distributed training with LightGBM on Dask.Ar ...@@ -5,7 +5,9 @@ This module enables you to perform distributed training with LightGBM on Dask.Ar
It is based on dask-xgboost package. It is based on dask-xgboost package.
""" """
import logging import logging
import socket
from collections import defaultdict from collections import defaultdict
from typing import Dict, Iterable
from urllib.parse import urlparse from urllib.parse import urlparse
import numpy as np import numpy as np
...@@ -13,7 +15,7 @@ import pandas as pd ...@@ -13,7 +15,7 @@ import pandas as pd
from dask import array as da from dask import array as da
from dask import dataframe as dd from dask import dataframe as dd
from dask import delayed from dask import delayed
from dask.distributed import default_client, get_worker, wait from dask.distributed import Client, default_client, get_worker, wait
from .basic import _LIB, _safe_call from .basic import _LIB, _safe_call
from .sklearn import LGBMClassifier, LGBMRegressor from .sklearn import LGBMClassifier, LGBMRegressor
...@@ -23,33 +25,84 @@ import scipy.sparse as ss ...@@ -23,33 +25,84 @@ import scipy.sparse as ss
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _parse_host_port(address): def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
parsed = urlparse(address) """Find an open port.
return parsed.hostname, parsed.port
This function tries to find a free port on the machine it's run on. It is intended to
be run once on each Dask worker, sequentially.
def _build_network_params(worker_addresses, local_worker_ip, local_listen_port, time_out): Parameters
"""Build network parameters suitable for LightGBM C backend. ----------
worker_ip : str
IP address for the Dask worker.
local_listen_port : int
First port to try when searching for open ports.
ports_to_skip: Iterable[int]
An iterable of integers referring to ports that should be skipped. Since multiple Dask
workers can run on the same physical machine, this method may be called multiple times
on the same machine. ``ports_to_skip`` is used to ensure that LightGBM doesn't try to use
the same port for two worker processes running on the same machine.
Returns
-------
result : int
A free port on the machine referenced by ``worker_ip``.
"""
max_tries = 1000
out_port = None
found_port = False
for i in range(max_tries):
out_port = local_listen_port + i
if out_port in ports_to_skip:
continue
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind((worker_ip, out_port))
found_port = True
break
# if unavailable, you'll get OSError: Address already in use
except OSError:
continue
if not found_port:
msg = "LightGBM tried %s:%d-%d and could not create a connection. Try setting local_listen_port to a different value."
raise RuntimeError(msg % (worker_ip, local_listen_port, out_port))
return out_port
def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], local_listen_port: int) -> Dict[str, int]:
    """Find an open port on each worker.

    LightGBM distributed training uses TCP sockets by default, and this method is used to
    identify open ports on each worker so LightGBM can reliably create those sockets.

    Parameters
    ----------
    client : dask.distributed.Client
        Dask client.
    worker_addresses : Iterable[str]
        An iterable of addresses for workers in the cluster. These are strings of the form ``<protocol>://<host>:port``
    local_listen_port : int
        First port to try when searching for open ports.

    Returns
    -------
    result : Dict[str, int]
        Dictionary where keys are worker addresses and values are an open port for LightGBM to use.
    """
    worker_address_to_port = {}
    ports_in_use = set()
    # search sequentially, pinning each search to its own worker, so that two workers
    # that share a physical machine never end up with the same port
    for address in worker_addresses:
        port_future = client.submit(
            func=_find_open_port,
            workers=[address],
            worker_ip=urlparse(address).hostname,
            local_listen_port=local_listen_port,
            ports_to_skip=ports_in_use
        )
        open_port = port_future.result()
        ports_in_use.add(open_port)
        worker_address_to_port[address] = open_port
    return worker_address_to_port
def _concat(seq): def _concat(seq):
...@@ -63,9 +116,20 @@ def _concat(seq): ...@@ -63,9 +116,20 @@ def _concat(seq):
raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0]))) raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0])))
def _train_part(params, model_factory, list_of_parts, worker_addresses, return_model, local_listen_port=12400, def _train_part(params, model_factory, list_of_parts, worker_address_to_port, return_model,
time_out=120, **kwargs): time_out=120, **kwargs):
network_params = _build_network_params(worker_addresses, get_worker().address, local_listen_port, time_out) local_worker_address = get_worker().address
machine_list = ','.join([
'%s:%d' % (urlparse(worker_address).hostname, port)
for worker_address, port
in worker_address_to_port.items()
])
network_params = {
'machines': machine_list,
'local_listen_port': worker_address_to_port[local_worker_address],
'time_out': time_out,
'num_machines': len(worker_address_to_port)
}
params.update(network_params) params.update(network_params)
# Concatenate many parts into one # Concatenate many parts into one
...@@ -138,13 +202,22 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs): ...@@ -138,13 +202,22 @@ def _train(client, data, label, params, model_factory, weight=None, **kwargs):
'(%s), using "data" as default', params.get("tree_learner", None)) '(%s), using "data" as default', params.get("tree_learner", None))
params['tree_learner'] = 'data' params['tree_learner'] = 'data'
# find an open port on each worker. note that multiple workers can run
# on the same machine, so this needs to ensure that each one gets its
# own port
local_listen_port = params.get('local_listen_port', 12400)
worker_address_to_port = _find_ports_for_workers(
client=client,
worker_addresses=worker_map.keys(),
local_listen_port=local_listen_port
)
# Tell each worker to train on the parts that it has locally # Tell each worker to train on the parts that it has locally
futures_classifiers = [client.submit(_train_part, futures_classifiers = [client.submit(_train_part,
model_factory=model_factory, model_factory=model_factory,
params={**params, 'num_threads': worker_ncores[worker]}, params={**params, 'num_threads': worker_ncores[worker]},
list_of_parts=list_of_parts, list_of_parts=list_of_parts,
worker_addresses=list(worker_map.keys()), worker_address_to_port=worker_address_to_port,
local_listen_port=params.get('local_listen_port', 12400),
time_out=params.get('time_out', 120), time_out=params.get('time_out', 120),
return_model=(worker == master_worker), return_model=(worker == master_worker),
**kwargs) **kwargs)
......
# coding: utf-8 # coding: utf-8
import os import os
import socket
import sys import sys
import pytest import pytest
...@@ -89,6 +90,26 @@ def test_classifier(output, centers, client, listen_port): ...@@ -89,6 +90,26 @@ def test_classifier(output, centers, client, listen_port):
assert_eq(y, p2) assert_eq(y, p2)
def test_training_does_not_fail_on_port_conflicts(client):
    # training should succeed even when the requested local_listen_port is
    # already occupied, because the port search skips busy ports
    _, _, _, dX, dy, dw = _create_data('classification', output='array')

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # occupy the default port for the duration of training
        s.bind(('127.0.0.1', 12400))

        dask_classifier = dlgbm.DaskLGBMClassifier(
            time_out=5,
            local_listen_port=12400
        )
        # repeat to make sure repeated fits don't trip over each other's ports
        for _ in range(5):
            dask_classifier.fit(X=dX, y=dy, sample_weight=dw, client=client)
            assert dask_classifier.booster_
@pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('centers', data_centers) @pytest.mark.parametrize('centers', data_centers)
def test_classifier_proba(output, centers, client, listen_port): def test_classifier_proba(output, centers, client, listen_port):
...@@ -183,21 +204,27 @@ def test_regressor_local_predict(client, listen_port): ...@@ -183,21 +204,27 @@ def test_regressor_local_predict(client, listen_port):
assert_eq(s1, s2) assert_eq(s1, s2)
def test_find_open_port_works():
    worker_ip = '127.0.0.1'

    # one busy port: search should return the next port up
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((worker_ip, 12400))
        new_port = dlgbm._find_open_port(
            worker_ip=worker_ip,
            local_listen_port=12400,
            ports_to_skip=set()
        )
    assert new_port == 12401

    # two consecutive busy ports: search should step over both
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s_1:
        s_1.bind((worker_ip, 12400))
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s_2:
            s_2.bind((worker_ip, 12401))
            new_port = dlgbm._find_open_port(
                worker_ip=worker_ip,
                local_listen_port=12400,
                ports_to_skip=set()
            )
    assert new_port == 12402
@gen_cluster(client=True, timeout=None) @gen_cluster(client=True, timeout=None)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment