Unverified commit bf1a604a, authored by James Lamb, committed by GitHub
Browse files

[python-package] [dask] add type annotations on dask._HostWorkers (#5766)

parent 98c1db77
...@@ -7,7 +7,7 @@ dask.Array and dask.DataFrame collections. ...@@ -7,7 +7,7 @@ dask.Array and dask.DataFrame collections.
It is based on dask-lightgbm, which was based on dask-xgboost. It is based on dask-lightgbm, which was based on dask-xgboost.
""" """
import socket import socket
from collections import defaultdict, namedtuple from collections import defaultdict
from copy import deepcopy from copy import deepcopy
from enum import Enum, auto from enum import Enum, auto
from functools import partial from functools import partial
...@@ -37,7 +37,18 @@ _DaskVectorLike = Union[dask_Array, dask_Series] ...@@ -37,7 +37,18 @@ _DaskVectorLike = Union[dask_Array, dask_Series]
_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix] _DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]] _PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
_HostWorkers = namedtuple('_HostWorkers', ['default', 'all'])
class _HostWorkers:
def __init__(self, default: str, all_workers: List[str]):
self.default = default
self.all_workers = all_workers
def __eq__(self, other: "_HostWorkers") -> bool:
return (
self.default == other.default
and self.all_workers == other.all_workers
)
class _DatasetNames(Enum): class _DatasetNames(Enum):
...@@ -105,9 +116,9 @@ def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWo ...@@ -105,9 +116,9 @@ def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWo
if not hostname: if not hostname:
raise ValueError(f"Could not parse host name from worker address '{address}'") raise ValueError(f"Could not parse host name from worker address '{address}'")
if hostname not in host_to_workers: if hostname not in host_to_workers:
host_to_workers[hostname] = _HostWorkers(default=address, all=[address]) host_to_workers[hostname] = _HostWorkers(default=address, all_workers=[address])
else: else:
host_to_workers[hostname].all.append(address) host_to_workers[hostname].all_workers.append(address)
return host_to_workers return host_to_workers
...@@ -124,7 +135,7 @@ def _assign_open_ports_to_workers( ...@@ -124,7 +135,7 @@ def _assign_open_ports_to_workers(
""" """
host_ports_futures = {} host_ports_futures = {}
for hostname, workers in host_to_workers.items(): for hostname, workers in host_to_workers.items():
n_workers_in_host = len(workers.all) n_workers_in_host = len(workers.all_workers)
host_ports_futures[hostname] = client.submit( host_ports_futures[hostname] = client.submit(
_find_n_open_ports, _find_n_open_ports,
n=n_workers_in_host, n=n_workers_in_host,
...@@ -135,7 +146,7 @@ def _assign_open_ports_to_workers( ...@@ -135,7 +146,7 @@ def _assign_open_ports_to_workers(
found_ports = client.gather(host_ports_futures) found_ports = client.gather(host_ports_futures)
worker_to_port = {} worker_to_port = {}
for hostname, workers in host_to_workers.items(): for hostname, workers in host_to_workers.items():
for worker, port in zip(workers.all, found_ports[hostname]): for worker, port in zip(workers.all_workers, found_ports[hostname]):
worker_to_port[worker] = port worker_to_port[worker] = port
return worker_to_port return worker_to_port
......
...@@ -525,7 +525,7 @@ def test_group_workers_by_host(): ...@@ -525,7 +525,7 @@ def test_group_workers_by_host():
expected = { expected = {
host: lgb.dask._HostWorkers( host: lgb.dask._HostWorkers(
default=f'tcp://{host}:0', default=f'tcp://{host}:0',
all=[f'tcp://{host}:0', f'tcp://{host}:1'] all_workers=[f'tcp://{host}:0', f'tcp://{host}:1']
) )
for host in hosts for host in hosts
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment