"tests/vscode:/vscode.git/clone" did not exist on "163416d2f56b6fe31e72138b8046b9e018c5e27c"
Unverified Commit 5857ef5e authored by José Morales's avatar José Morales Committed by GitHub
Browse files

[tests][dask] Use workers hostname in tests (fixes #4594) (#4595)


Co-authored-by: default avatarNikita Titov <nekit94-12@hotmail.com>
parent d411bced
......@@ -9,6 +9,7 @@ from itertools import groupby
from os import getenv
from platform import machine
from sys import platform
from urllib.parse import urlparse
import pytest
......@@ -87,6 +88,11 @@ def listen_port():
listen_port.port = 13000
def _get_workers_hostname(cluster: LocalCluster) -> str:
one_worker_address = next(iter(cluster.scheduler_info['workers']))
return urlparse(one_worker_address).hostname
def _create_ranking_data(n_samples=100, output='array', chunk_size=50, **kwargs):
X, y, g = make_ranking(n_samples=n_samples, random_state=42, **kwargs)
rnd = np.random.RandomState(42)
......@@ -485,8 +491,9 @@ def test_training_does_not_fail_on_port_conflicts(cluster):
_, _, _, _, dX, dy, dw, _ = _create_data('binary-classification', output='array')
lightgbm_default_port = 12400
workers_hostname = _get_workers_hostname(cluster)
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', lightgbm_default_port))
s.bind((workers_hostname, lightgbm_default_port))
dask_classifier = lgb.DaskLGBMClassifier(
client=client,
time_out=5,
......@@ -1395,13 +1402,14 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c
assert 'machines' not in params
# model 2 - machines given
workers_hostname = _get_workers_hostname(cluster)
n_workers = len(client.scheduler_info()['workers'])
open_ports = lgb.dask._find_n_open_ports(n_workers)
dask_model2 = dask_model_factory(
n_estimators=5,
num_leaves=5,
machines=",".join([
f"127.0.0.1:{port}"
f"{workers_hostname}:{port}"
for port in open_ports
]),
)
......@@ -1442,12 +1450,13 @@ def test_machines_should_be_used_if_provided(task, cluster):
n_workers = len(client.scheduler_info()['workers'])
assert n_workers > 1
workers_hostname = _get_workers_hostname(cluster)
open_ports = lgb.dask._find_n_open_ports(n_workers)
dask_model = dask_model_factory(
n_estimators=5,
num_leaves=5,
machines=",".join([
f"127.0.0.1:{port}"
f"{workers_hostname}:{port}"
for port in open_ports
]),
)
......@@ -1457,7 +1466,7 @@ def test_machines_should_be_used_if_provided(task, cluster):
error_msg = f"Binding port {open_ports[0]} failed"
with pytest.raises(lgb.basic.LightGBMError, match=error_msg):
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(('127.0.0.1', open_ports[0]))
s.bind((workers_hostname, open_ports[0]))
dask_model.fit(dX, dy, group=dg)
# The above error leaves a worker waiting
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment