"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "36732f23baa9e87617ea7700a9f59d7f53e313c6"
Unverified Commit 1ce4b22b authored by James Lamb, committed by GitHub
Browse files

[dask] make random port search more resilient to random collisions (fixes #4057) (#4133)

* [dask] make random port search more resilient to random collisions

* linting

* more reliable ports check

* address review comments

* add error message
parent 9388b2ec
...@@ -170,6 +170,44 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[ ...@@ -170,6 +170,44 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[
return out return out
def _possibly_fix_worker_map_duplicates(worker_map: Dict[str, int], client: Client) -> Dict[str, int]:
    """Fix any duplicate IP-port pairs in a ``worker_map``.

    Parameters
    ----------
    worker_map : dict
        Mapping of worker address to the port chosen for that worker.
    client : Client
        Client used to run the port search on one specific worker at a time.

    Returns
    -------
    dict
        Copy of ``worker_map`` in which no two workers on the same host share
        a port. The input mapping is not modified.

    Raises
    ------
    LightGBMError
        If every one of the 100 search attempts for a worker returns a port
        that is already taken on that worker's host.
    """
    max_attempts = 100

    # work on a copy so the caller's mapping is left untouched
    worker_map = deepcopy(worker_map)
    workers_that_need_new_ports = []
    host_to_port = defaultdict(set)

    # the first worker seen with a given (host, port) keeps it; later ones are queued
    for worker, port in worker_map.items():
        host = urlparse(worker).hostname
        if port in host_to_port[host]:
            workers_that_need_new_ports.append(worker)
        else:
            host_to_port[host].add(port)

    # if any duplicates were found, search for new ports one by one
    for worker in workers_that_need_new_ports:
        _log_info(f"Searching for a LightGBM training port for worker '{worker}'")
        claimed_ports = host_to_port[urlparse(worker).hostname]
        for _ in range(max_attempts):
            new_port = client.submit(
                _find_random_open_port,
                workers=[worker],
                allow_other_workers=False,
                pure=False
            ).result()
            if new_port not in claimed_ports:
                worker_map[worker] = new_port
                claimed_ports.add(new_port)
                break
        else:
            # reached only when no attempt hit ``break``; a success on the
            # final attempt must not be reported as a failure
            raise LightGBMError(
                "Failed to find an open port. Try re-running training or explicitly setting 'machines' or 'local_listen_port'."
            )

    return worker_map
def _train( def _train(
client: Client, client: Client,
data: _DaskMatrixLike, data: _DaskMatrixLike,
...@@ -367,10 +405,19 @@ def _train( ...@@ -367,10 +405,19 @@ def _train(
} }
else: else:
_log_info("Finding random open ports for workers") _log_info("Finding random open ports for workers")
# this approach with client.run() is faster than searching for ports
# serially, but can produce duplicates sometimes. Try the fast approach one
# time, then pass it through a function that will use a slower but more reliable
# approach if duplicates are found.
worker_address_to_port = client.run( worker_address_to_port = client.run(
_find_random_open_port, _find_random_open_port,
workers=list(worker_addresses) workers=list(worker_addresses)
) )
worker_address_to_port = _possibly_fix_worker_map_duplicates(
worker_map=worker_address_to_port,
client=client
)
machines = ','.join([ machines = ','.join([
'%s:%d' % (urlparse(worker_address).hostname, port) '%s:%d' % (urlparse(worker_address).hostname, port)
for worker_address, port for worker_address, port
......
...@@ -392,6 +392,37 @@ def test_find_random_open_port(client): ...@@ -392,6 +392,37 @@ def test_find_random_open_port(client):
client.close(timeout=CLIENT_CLOSE_TIMEOUT) client.close(timeout=CLIENT_CLOSE_TIMEOUT)
def test_possibly_fix_worker_map(capsys, client):
    """Check that duplicate ports in a worker map are detected and replaced."""
    client.wait_for_workers(2)
    scheduler_workers = client.scheduler_info()["workers"]
    worker_addresses = list(scheduler_workers.keys())
    retry_msg = 'Searching for a LightGBM training port for worker'

    # a map whose ports are already distinct must come back unchanged,
    # and no port search may be logged
    map_without_duplicates = {}
    for offset, address in enumerate(worker_addresses):
        map_without_duplicates[address] = 12400 + offset
    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
        worker_map=map_without_duplicates,
        client=client
    )
    assert patched_map == map_without_duplicates
    captured = capsys.readouterr()
    assert retry_msg not in captured.out

    # every worker starts on the same port, so a search must be logged
    # and the resulting ports must all differ
    map_with_duplicates = dict.fromkeys(worker_addresses, 12400)
    patched_map = lgb.dask._possibly_fix_worker_map_duplicates(
        worker_map=map_with_duplicates,
        client=client
    )
    captured = capsys.readouterr()
    assert retry_msg in captured.out
    unique_ports = set(patched_map.values())
    assert len(unique_ports) == len(worker_addresses)
def test_training_does_not_fail_on_port_conflicts(client): def test_training_does_not_fail_on_port_conflicts(client):
_, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array') _, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment