Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
bf1a604a
Unverified
Commit
bf1a604a
authored
Mar 06, 2023
by
James Lamb
Committed by
GitHub
Mar 06, 2023
Browse files
[python-package] [dask] add type annotations on dask._HostWorkers (#5766)
parent
98c1db77
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
7 deletions
+18
-7
python-package/lightgbm/dask.py
python-package/lightgbm/dask.py
+17
-6
tests/python_package_test/test_dask.py
tests/python_package_test/test_dask.py
+1
-1
No files found.
python-package/lightgbm/dask.py
View file @
bf1a604a
...
@@ -7,7 +7,7 @@ dask.Array and dask.DataFrame collections.
...
@@ -7,7 +7,7 @@ dask.Array and dask.DataFrame collections.
It is based on dask-lightgbm, which was based on dask-xgboost.
It is based on dask-lightgbm, which was based on dask-xgboost.
"""
"""
import
socket
import
socket
from
collections
import
defaultdict
,
namedtuple
from
collections
import
defaultdict
from
copy
import
deepcopy
from
copy
import
deepcopy
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
functools
import
partial
from
functools
import
partial
...
@@ -37,7 +37,18 @@ _DaskVectorLike = Union[dask_Array, dask_Series]
...
@@ -37,7 +37,18 @@ _DaskVectorLike = Union[dask_Array, dask_Series]
_DaskPart
=
Union
[
np
.
ndarray
,
pd_DataFrame
,
pd_Series
,
ss
.
spmatrix
]
_DaskPart
=
Union
[
np
.
ndarray
,
pd_DataFrame
,
pd_Series
,
ss
.
spmatrix
]
_PredictionDtype
=
Union
[
Type
[
np
.
float32
],
Type
[
np
.
float64
],
Type
[
np
.
int32
],
Type
[
np
.
int64
]]
_PredictionDtype
=
Union
[
Type
[
np
.
float32
],
Type
[
np
.
float64
],
Type
[
np
.
int32
],
Type
[
np
.
int64
]]
_HostWorkers
=
namedtuple
(
'_HostWorkers'
,
[
'default'
,
'all'
])
class
_HostWorkers
:
def
__init__
(
self
,
default
:
str
,
all_workers
:
List
[
str
]):
self
.
default
=
default
self
.
all_workers
=
all_workers
def
__eq__
(
self
,
other
:
"_HostWorkers"
)
->
bool
:
return
(
self
.
default
==
other
.
default
and
self
.
all_workers
==
other
.
all_workers
)
class
_DatasetNames
(
Enum
):
class
_DatasetNames
(
Enum
):
...
@@ -105,9 +116,9 @@ def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWo
...
@@ -105,9 +116,9 @@ def _group_workers_by_host(worker_addresses: Iterable[str]) -> Dict[str, _HostWo
if
not
hostname
:
if
not
hostname
:
raise
ValueError
(
f
"Could not parse host name from worker address '
{
address
}
'"
)
raise
ValueError
(
f
"Could not parse host name from worker address '
{
address
}
'"
)
if
hostname
not
in
host_to_workers
:
if
hostname
not
in
host_to_workers
:
host_to_workers
[
hostname
]
=
_HostWorkers
(
default
=
address
,
all
=
[
address
])
host_to_workers
[
hostname
]
=
_HostWorkers
(
default
=
address
,
all
_workers
=
[
address
])
else
:
else
:
host_to_workers
[
hostname
].
all
.
append
(
address
)
host_to_workers
[
hostname
].
all
_workers
.
append
(
address
)
return
host_to_workers
return
host_to_workers
...
@@ -124,7 +135,7 @@ def _assign_open_ports_to_workers(
...
@@ -124,7 +135,7 @@ def _assign_open_ports_to_workers(
"""
"""
host_ports_futures
=
{}
host_ports_futures
=
{}
for
hostname
,
workers
in
host_to_workers
.
items
():
for
hostname
,
workers
in
host_to_workers
.
items
():
n_workers_in_host
=
len
(
workers
.
all
)
n_workers_in_host
=
len
(
workers
.
all
_workers
)
host_ports_futures
[
hostname
]
=
client
.
submit
(
host_ports_futures
[
hostname
]
=
client
.
submit
(
_find_n_open_ports
,
_find_n_open_ports
,
n
=
n_workers_in_host
,
n
=
n_workers_in_host
,
...
@@ -135,7 +146,7 @@ def _assign_open_ports_to_workers(
...
@@ -135,7 +146,7 @@ def _assign_open_ports_to_workers(
found_ports
=
client
.
gather
(
host_ports_futures
)
found_ports
=
client
.
gather
(
host_ports_futures
)
worker_to_port
=
{}
worker_to_port
=
{}
for
hostname
,
workers
in
host_to_workers
.
items
():
for
hostname
,
workers
in
host_to_workers
.
items
():
for
worker
,
port
in
zip
(
workers
.
all
,
found_ports
[
hostname
]):
for
worker
,
port
in
zip
(
workers
.
all
_workers
,
found_ports
[
hostname
]):
worker_to_port
[
worker
]
=
port
worker_to_port
[
worker
]
=
port
return
worker_to_port
return
worker_to_port
...
...
tests/python_package_test/test_dask.py
View file @
bf1a604a
...
@@ -525,7 +525,7 @@ def test_group_workers_by_host():
...
@@ -525,7 +525,7 @@ def test_group_workers_by_host():
expected
=
{
expected
=
{
host
:
lgb
.
dask
.
_HostWorkers
(
host
:
lgb
.
dask
.
_HostWorkers
(
default
=
f
'tcp://
{
host
}
:0'
,
default
=
f
'tcp://
{
host
}
:0'
,
all
=
[
f
'tcp://
{
host
}
:0'
,
f
'tcp://
{
host
}
:1'
]
all
_workers
=
[
f
'tcp://
{
host
}
:0'
,
f
'tcp://
{
host
}
:1'
]
)
)
for
host
in
hosts
for
host
in
hosts
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment