Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
93f2da43
Unverified
Commit
93f2da43
authored
Nov 21, 2022
by
José Morales
Committed by
GitHub
Nov 21, 2022
Browse files
[tests][dask] fix workers without data test (fixes #5537) (#5544)
parent
2d4654a1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
36 additions
and
30 deletions
+36
-30
tests/python_package_test/test_dask.py
tests/python_package_test/test_dask.py
+36
-30
No files found.
tests/python_package_test/test_dask.py
View file @
93f2da43
...
@@ -11,6 +11,7 @@ from sys import platform
...
@@ -11,6 +11,7 @@ from sys import platform
from
urllib.parse
import
urlparse
from
urllib.parse
import
urlparse
import
pytest
import
pytest
from
sklearn.metrics
import
accuracy_score
,
r2_score
import
lightgbm
as
lgb
import
lightgbm
as
lgb
...
@@ -75,6 +76,13 @@ def cluster2():
...
@@ -75,6 +76,13 @@ def cluster2():
dask_cluster
.
close
()
dask_cluster
.
close
()
@
pytest
.
fixture
(
scope
=
'module'
)
def
cluster_three_workers
():
dask_cluster
=
LocalCluster
(
n_workers
=
3
,
threads_per_worker
=
1
,
dashboard_address
=
None
)
yield
dask_cluster
dask_cluster
.
close
()
@
pytest
.
fixture
()
@
pytest
.
fixture
()
def
listen_port
():
def
listen_port
():
listen_port
.
port
+=
10
listen_port
.
port
+=
10
...
@@ -1503,56 +1511,54 @@ def test_errors(cluster):
...
@@ -1503,56 +1511,54 @@ def test_errors(cluster):
@
pytest
.
mark
.
parametrize
(
'task'
,
tasks
)
@
pytest
.
mark
.
parametrize
(
'task'
,
tasks
)
@
pytest
.
mark
.
parametrize
(
'output'
,
data_output
)
@
pytest
.
mark
.
parametrize
(
'output'
,
data_output
)
def
test_training_succeeds_even_if_some_workers_do_not_have_any_data
(
task
,
output
,
cluster
):
def
test_training_succeeds_even_if_some_workers_do_not_have_any_data
(
task
,
output
,
cluster_three_workers
):
pytest
.
skip
(
"skipping due to timeout issues discussed in https://github.com/microsoft/LightGBM/pull/5510"
)
if
task
==
'ranking'
and
output
==
'scipy_csr_matrix'
:
if
task
==
'ranking'
and
output
==
'scipy_csr_matrix'
:
pytest
.
skip
(
'LGBMRanker is not currently tested on sparse matrices'
)
pytest
.
skip
(
'LGBMRanker is not currently tested on sparse matrices'
)
with
Client
(
cluster
)
as
client
:
with
Client
(
cluster_three_workers
)
as
client
:
def
collection_to_single_partition
(
collection
):
_
,
y
,
_
,
_
,
dX
,
dy
,
dw
,
dg
=
_create_data
(
"""Merge the parts of a Dask collection into a single partition."""
if
collection
is
None
:
return
if
isinstance
(
collection
,
da
.
Array
):
return
collection
.
rechunk
(
*
collection
.
shape
)
return
collection
.
repartition
(
npartitions
=
1
)
X
,
y
,
w
,
g
,
dX
,
dy
,
dw
,
dg
=
_create_data
(
objective
=
task
,
objective
=
task
,
output
=
output
,
output
=
output
,
group
=
None
group
=
None
,
n_samples
=
1_000
,
chunk_size
=
200
,
)
)
dask_model_factory
=
task_to_dask_factory
[
task
]
dask_model_factory
=
task_to_dask_factory
[
task
]
local_model_factory
=
task_to_local_factory
[
task
]
dX
=
collection_to_single_partition
(
dX
)
workers
=
list
(
client
.
scheduler_info
()[
'workers'
].
keys
())
dy
=
collection_to_single_partition
(
dy
)
assert
len
(
workers
)
==
3
dw
=
collection_to_single_partition
(
dw
)
first_two_workers
=
workers
[:
2
]
dg
=
collection_to_single_partition
(
dg
)
n_workers
=
len
(
client
.
scheduler_info
()[
'workers'
])
dX
=
client
.
persist
(
dX
,
workers
=
first_two_workers
)
assert
n_workers
>
1
dy
=
client
.
persist
(
dy
,
workers
=
first_two_workers
)
assert
dX
.
npartitions
==
1
dw
=
client
.
persist
(
dw
,
workers
=
first_two_workers
)
wait
([
dX
,
dy
,
dw
])
workers_with_data
=
set
()
for
coll
in
(
dX
,
dy
,
dw
):
for
with_data
in
client
.
who_has
(
coll
).
values
():
workers_with_data
.
update
(
with_data
)
assert
workers
[
2
]
not
in
with_data
assert
len
(
workers_with_data
)
==
2
params
=
{
params
=
{
'time_out'
:
5
,
'time_out'
:
5
,
'random_state'
:
42
,
'random_state'
:
42
,
'num_leaves'
:
10
'num_leaves'
:
10
,
'n_estimators'
:
20
,
}
}
dask_model
=
dask_model_factory
(
tree
=
'data'
,
client
=
client
,
**
params
)
dask_model
=
dask_model_factory
(
tree
=
'data'
,
client
=
client
,
**
params
)
dask_model
.
fit
(
dX
,
dy
,
group
=
dg
,
sample_weight
=
dw
)
dask_model
.
fit
(
dX
,
dy
,
group
=
dg
,
sample_weight
=
dw
)
dask_preds
=
dask_model
.
predict
(
dX
).
compute
()
dask_preds
=
dask_model
.
predict
(
dX
).
compute
()
if
task
==
'regression'
:
local_model
=
local_model_factory
(
**
param
s
)
score
=
r2_score
(
y
,
dask_pred
s
)
if
task
==
'ranking'
:
el
if
task
.
endswith
(
'classification'
)
:
local_model
.
fit
(
X
,
y
,
group
=
g
,
sample_weight
=
w
)
score
=
accuracy_score
(
y
,
dask_preds
)
else
:
else
:
local_model
.
fit
(
X
,
y
,
sample_weight
=
w
)
score
=
spearmanr
(
dask_preds
,
y
).
correlation
local_preds
=
local_model
.
predict
(
X
)
assert
score
>
0.9
assert
assert_eq
(
dask_preds
,
local_preds
)
@
pytest
.
mark
.
parametrize
(
'task'
,
tasks
)
@
pytest
.
mark
.
parametrize
(
'task'
,
tasks
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment