Commit bf1a604a (unverified), authored by James Lamb, committed by GitHub
Browse files

[python-package] [dask] add type annotations on dask._HostWorkers (#5766)

parent 98c1db77
......@@ -7,7 +7,7 @@ dask.Array and dask.DataFrame collections.
It is based on dask-lightgbm, which was based on dask-xgboost.
"""
import socket
from collections import defaultdict, namedtuple
from collections import defaultdict
from copy import deepcopy
from enum import Enum, auto
from functools import partial
......@@ -37,7 +37,18 @@ _DaskVectorLike = Union[dask_Array, dask_Series]
_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
_HostWorkers = namedtuple('_HostWorkers', ['default', 'all'])
class _HostWorkers:
def __init__(self, default: str, all_workers: List[str]):
self.default = default
self.all_workers = all_workers
def __eq__(self, other: "_HostWorkers") -> bool:
return (
self.default == other.default
and self.all_workers == other.all_workers
)
class _DatasetNames(Enum):
......@@ -105,9 +116,9 @@ def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWo
if not hostname:
raise ValueError(f"Could not parse host name from worker address '{address}'")
if hostname not in host_to_workers:
host_to_workers[hostname] = _HostWorkers(default=address, all=[address])
host_to_workers[hostname] = _HostWorkers(default=address, all_workers=[address])
else:
host_to_workers[hostname].all.append(address)
host_to_workers[hostname].all_workers.append(address)
return host_to_workers
......@@ -124,7 +135,7 @@ def _assign_open_ports_to_workers(
"""
host_ports_futures = {}
for hostname, workers in host_to_workers.items():
n_workers_in_host = len(workers.all)
n_workers_in_host = len(workers.all_workers)
host_ports_futures[hostname] = client.submit(
_find_n_open_ports,
n=n_workers_in_host,
......@@ -135,7 +146,7 @@ def _assign_open_ports_to_workers(
found_ports = client.gather(host_ports_futures)
worker_to_port = {}
for hostname, workers in host_to_workers.items():
for worker, port in zip(workers.all, found_ports[hostname]):
for worker, port in zip(workers.all_workers, found_ports[hostname]):
worker_to_port[worker] = port
return worker_to_port
......
......@@ -525,7 +525,7 @@ def test_group_workers_by_host():
expected = {
host: lgb.dask._HostWorkers(
default=f'tcp://{host}:0',
all=[f'tcp://{host}:0', f'tcp://{host}:1']
all_workers=[f'tcp://{host}:0', f'tcp://{host}:1']
)
for host in hosts
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment