"...AutoBuildImmortalWrt.git" did not exist on "1bec9b5e6f665491586a2d67db19005fbfe8e5ba"
Unverified Commit 296397df authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[dask] raise more informative error for duplicates in 'machines' (fixes #4057) (#4059)

* [dask] raise more informative error for duplicates in 'machines'

* uncomment

* avoid test failure

* Revert "avoid test failure"

This reverts commit 9442bdf00f193a19a923dc0deb46b7822cb6f601.
parent b75a43a0
...@@ -153,6 +153,10 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[ ...@@ -153,6 +153,10 @@ def _machines_to_worker_map(machines: str, worker_addresses: List[str]) -> Dict[
Dictionary where keys are work addresses in the form expected by Dask and values are a port for LightGBM to use. Dictionary where keys are work addresses in the form expected by Dask and values are a port for LightGBM to use.
""" """
machine_addresses = machines.split(",") machine_addresses = machines.split(",")
if len(set(machine_addresses)) != len(machine_addresses):
raise ValueError(f"Found duplicates in 'machines' ({machines}). Each entry in 'machines' must be a unique IP-port combination.")
machine_to_port = defaultdict(set) machine_to_port = defaultdict(set)
for address in machine_addresses: for address in machine_addresses:
host, port = address.split(":") host, port = address.split(":")
......
...@@ -1116,6 +1116,7 @@ def test_machines_should_be_used_if_provided(task, output): ...@@ -1116,6 +1116,7 @@ def test_machines_should_be_used_if_provided(task, output):
client.rebalance() client.rebalance()
n_workers = len(client.scheduler_info()['workers']) n_workers = len(client.scheduler_info()['workers'])
assert n_workers > 1
open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)] open_ports = [lgb.dask._find_random_open_port() for _ in range(n_workers)]
dask_model = dask_model_factory( dask_model = dask_model_factory(
n_estimators=5, n_estimators=5,
...@@ -1134,6 +1135,17 @@ def test_machines_should_be_used_if_provided(task, output): ...@@ -1134,6 +1135,17 @@ def test_machines_should_be_used_if_provided(task, output):
s.bind(('127.0.0.1', open_ports[0])) s.bind(('127.0.0.1', open_ports[0]))
dask_model.fit(dX, dy, group=dg) dask_model.fit(dX, dy, group=dg)
# an informative error should be raised if "machines" has duplicates
one_open_port = lgb.dask._find_random_open_port()
dask_model.set_params(
machines=",".join([
"127.0.0.1:" + str(one_open_port)
for _ in range(n_workers)
])
)
with pytest.raises(ValueError, match="Found duplicates in 'machines'"):
dask_model.fit(dX, dy, group=dg)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"classes", "classes",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment