Unverified commit 5857ef5e, authored by José Morales, committed by GitHub
Browse files

[tests][dask] Use workers hostname in tests (fixes #4594) (#4595)


Co-authored-by: Nikita Titov <nekit94-12@hotmail.com>
parent d411bced
...@@ -9,6 +9,7 @@ from itertools import groupby ...@@ -9,6 +9,7 @@ from itertools import groupby
from os import getenv from os import getenv
from platform import machine from platform import machine
from sys import platform from sys import platform
from urllib.parse import urlparse
import pytest import pytest
...@@ -87,6 +88,11 @@ def listen_port(): ...@@ -87,6 +88,11 @@ def listen_port():
listen_port.port = 13000 listen_port.port = 13000
def _get_workers_hostname(cluster: "LocalCluster") -> str:
    """Return the hostname parsed from one of the cluster's worker addresses.

    The first entry obtained by iterating ``cluster.scheduler_info['workers']``
    is treated as a worker address (e.g. ``tcp://127.0.0.1:36789``) and its
    host component is extracted with ``urllib.parse.urlparse``.

    Parameters
    ----------
    cluster : LocalCluster
        Object exposing a ``scheduler_info`` mapping with a ``'workers'``
        entry; iterating that entry must yield worker address strings.

    Returns
    -------
    str
        The hostname component of one worker address.

    Raises
    ------
    RuntimeError
        If the cluster reports no workers, or if the selected address has no
        hostname that ``urlparse`` can extract (for example, it lacks the
        ``scheme://`` prefix).
    """
    worker_addresses = cluster.scheduler_info['workers']
    # Use a default so an empty worker set produces a clear error instead of
    # a bare StopIteration escaping from next().
    one_worker_address = next(iter(worker_addresses), None)
    if one_worker_address is None:
        raise RuntimeError("Cannot determine workers hostname: cluster has no workers")
    hostname = urlparse(one_worker_address).hostname
    # urlparse reports hostname=None when the address has no '//host' part;
    # fail here rather than passing None on to socket.bind() or an f-string.
    if hostname is None:
        raise RuntimeError(
            f"Cannot determine workers hostname from address {one_worker_address!r}"
        )
    return hostname
def _create_ranking_data(n_samples=100, output='array', chunk_size=50, **kwargs): def _create_ranking_data(n_samples=100, output='array', chunk_size=50, **kwargs):
X, y, g = make_ranking(n_samples=n_samples, random_state=42, **kwargs) X, y, g = make_ranking(n_samples=n_samples, random_state=42, **kwargs)
rnd = np.random.RandomState(42) rnd = np.random.RandomState(42)
...@@ -485,8 +491,9 @@ def test_training_does_not_fail_on_port_conflicts(cluster): ...@@ -485,8 +491,9 @@ def test_training_does_not_fail_on_port_conflicts(cluster):
_, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array') _, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
lightgbm_default_port = 12400 lightgbm_default_port = 12400
workers_hostname = _get_workers_hostname(cluster)
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', lightgbm_default_port)) s.bind((workers_hostname, lightgbm_default_port))
dask_classifier = lgb.DaskLGBMClassifier( dask_classifier = lgb.DaskLGBMClassifier(
client=client, client=client,
time_out=5, time_out=5,
...@@ -1395,13 +1402,14 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c ...@@ -1395,13 +1402,14 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c
assert 'machines' not in params assert 'machines' not in params
# model 2 - machines given # model 2 - machines given
workers_hostname = _get_workers_hostname(cluster)
n_workers = len(client.scheduler_info()['workers']) n_workers = len(client.scheduler_info()['workers'])
open_ports = lgb.dask._find_n_open_ports(n_workers) open_ports = lgb.dask._find_n_open_ports(n_workers)
dask_model2 = dask_model_factory( dask_model2 = dask_model_factory(
n_estimators=5, n_estimators=5,
num_leaves=5, num_leaves=5,
machines=",".join([ machines=",".join([
f"127.0.0.1:{port}" f"{workers_hostname}:{port}"
for port in open_ports for port in open_ports
]), ]),
) )
...@@ -1442,12 +1450,13 @@ def test_machines_should_be_used_if_provided(task, cluster): ...@@ -1442,12 +1450,13 @@ def test_machines_should_be_used_if_provided(task, cluster):
n_workers = len(client.scheduler_info()['workers']) n_workers = len(client.scheduler_info()['workers'])
assert n_workers > 1 assert n_workers > 1
workers_hostname = _get_workers_hostname(cluster)
open_ports = lgb.dask._find_n_open_ports(n_workers) open_ports = lgb.dask._find_n_open_ports(n_workers)
dask_model = dask_model_factory( dask_model = dask_model_factory(
n_estimators=5, n_estimators=5,
num_leaves=5, num_leaves=5,
machines=",".join([ machines=",".join([
f"127.0.0.1:{port}" f"{workers_hostname}:{port}"
for port in open_ports for port in open_ports
]), ]),
) )
...@@ -1457,7 +1466,7 @@ def test_machines_should_be_used_if_provided(task, cluster): ...@@ -1457,7 +1466,7 @@ def test_machines_should_be_used_if_provided(task, cluster):
error_msg = f"Binding port {open_ports[0]} failed" error_msg = f"Binding port {open_ports[0]} failed"
with pytest.raises(lgb.basic.LightGBMError, match=error_msg): with pytest.raises(lgb.basic.LightGBMError, match=error_msg):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', open_ports[0])) s.bind((workers_hostname, open_ports[0]))
dask_model.fit(dX, dy, group=dg) dask_model.fit(dX, dy, group=dg)
# The above error leaves a worker waiting # The above error leaves a worker waiting
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment