Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
5857ef5e
Unverified
Commit
5857ef5e
authored
Sep 09, 2021
by
José Morales
Committed by
GitHub
Sep 09, 2021
Browse files
[tests][dask] Use workers hostname in tests (fixes #4594) (#4595)
Co-authored-by:
Nikita Titov
<
nekit94-12@hotmail.com
>
parent
d411bced
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
4 deletions
+13
-4
tests/python_package_test/test_dask.py
tests/python_package_test/test_dask.py
+13
-4
No files found.
tests/python_package_test/test_dask.py
View file @
5857ef5e
...
...
@@ -9,6 +9,7 @@ from itertools import groupby
from
os
import
getenv
from
platform
import
machine
from
sys
import
platform
from
urllib.parse
import
urlparse
import
pytest
...
...
@@ -87,6 +88,11 @@ def listen_port():
listen_port
.
port
=
13000
def
_get_workers_hostname
(
cluster
:
LocalCluster
)
->
str
:
one_worker_address
=
next
(
iter
(
cluster
.
scheduler_info
[
'workers'
]))
return
urlparse
(
one_worker_address
).
hostname
def
_create_ranking_data
(
n_samples
=
100
,
output
=
'array'
,
chunk_size
=
50
,
**
kwargs
):
X
,
y
,
g
=
make_ranking
(
n_samples
=
n_samples
,
random_state
=
42
,
**
kwargs
)
rnd
=
np
.
random
.
RandomState
(
42
)
...
...
@@ -485,8 +491,9 @@ def test_training_does_not_fail_on_port_conflicts(cluster):
_
,
_
,
_
,
_
,
dX
,
dy
,
dw
,
_
=
_create_data
(
'binary-classification'
,
output
=
'array'
)
lightgbm_default_port
=
12400
workers_hostname
=
_get_workers_hostname
(
cluster
)
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s
:
s
.
bind
((
'127.0.0.1'
,
lightgbm_default_port
))
s
.
bind
((
workers_hostname
,
lightgbm_default_port
))
dask_classifier
=
lgb
.
DaskLGBMClassifier
(
client
=
client
,
time_out
=
5
,
...
...
@@ -1395,13 +1402,14 @@ def test_network_params_not_required_but_respected_if_given(task, listen_port, c
assert
'machines'
not
in
params
# model 2 - machines given
workers_hostname
=
_get_workers_hostname
(
cluster
)
n_workers
=
len
(
client
.
scheduler_info
()[
'workers'
])
open_ports
=
lgb
.
dask
.
_find_n_open_ports
(
n_workers
)
dask_model2
=
dask_model_factory
(
n_estimators
=
5
,
num_leaves
=
5
,
machines
=
","
.
join
([
f
"
127.0.0.1
:
{
port
}
"
f
"
{
workers_hostname
}
:
{
port
}
"
for
port
in
open_ports
]),
)
...
...
@@ -1442,12 +1450,13 @@ def test_machines_should_be_used_if_provided(task, cluster):
n_workers
=
len
(
client
.
scheduler_info
()[
'workers'
])
assert
n_workers
>
1
workers_hostname
=
_get_workers_hostname
(
cluster
)
open_ports
=
lgb
.
dask
.
_find_n_open_ports
(
n_workers
)
dask_model
=
dask_model_factory
(
n_estimators
=
5
,
num_leaves
=
5
,
machines
=
","
.
join
([
f
"
127.0.0.1
:
{
port
}
"
f
"
{
workers_hostname
}
:
{
port
}
"
for
port
in
open_ports
]),
)
...
...
@@ -1457,7 +1466,7 @@ def test_machines_should_be_used_if_provided(task, cluster):
error_msg
=
f
"Binding port
{
open_ports
[
0
]
}
failed"
with
pytest
.
raises
(
lgb
.
basic
.
LightGBMError
,
match
=
error_msg
):
with
socket
.
socket
(
socket
.
AF_INET
,
socket
.
SOCK_STREAM
)
as
s
:
s
.
bind
((
'127.0.0.1'
,
open_ports
[
0
]))
s
.
bind
((
workers_hostname
,
open_ports
[
0
]))
dask_model
.
fit
(
dX
,
dy
,
group
=
dg
)
# The above error leaves a worker waiting
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment