Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
a4cae37c
Unverified
Commit
a4cae37c
authored
Feb 02, 2021
by
Frank Fineis
Committed by
GitHub
Feb 02, 2021
Browse files
rebalance dask.array ranker input (#3892)
parent
0c71be74
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
2 deletions
+11
-2
tests/python_package_test/test_dask.py
tests/python_package_test/test_dask.py
+11
-2
No files found.
tests/python_package_test/test_dask.py
View file @
a4cae37c
...
@@ -19,10 +19,10 @@ import numpy as np
...
@@ -19,10 +19,10 @@ import numpy as np
import
pandas
as
pd
import
pandas
as
pd
from
scipy.stats
import
spearmanr
from
scipy.stats
import
spearmanr
from
dask.array.utils
import
assert_eq
from
dask.array.utils
import
assert_eq
from
dask.distributed
import
wait
from
distributed.utils_test
import
client
,
cluster_fixture
,
gen_cluster
,
loop
from
distributed.utils_test
import
client
,
cluster_fixture
,
gen_cluster
,
loop
from
scipy.sparse
import
csr_matrix
from
scipy.sparse
import
csr_matrix
from
sklearn.datasets
import
make_blobs
,
make_regression
from
sklearn.datasets
import
make_blobs
,
make_regression
from
sklearn.utils
import
check_random_state
from
.utils
import
make_ranking
from
.utils
import
make_ranking
...
@@ -382,6 +382,15 @@ def test_ranker(output, client, listen_port, group):
...
@@ -382,6 +382,15 @@ def test_ranker(output, client, listen_port, group):
group
=
group
group
=
group
)
)
# rebalance small dask.array dataset for better performance.
if
output
==
'array'
:
dX
=
dX
.
persist
()
dy
=
dy
.
persist
()
dw
=
dw
.
persist
()
dg
=
dg
.
persist
()
_
=
wait
([
dX
,
dy
,
dw
,
dg
])
client
.
rebalance
()
# use many trees + leaves to overfit, help ensure that dask data-parallel strategy matches that of
# use many trees + leaves to overfit, help ensure that dask data-parallel strategy matches that of
# serial learner. See https://github.com/microsoft/LightGBM/issues/3292#issuecomment-671288210.
# serial learner. See https://github.com/microsoft/LightGBM/issues/3292#issuecomment-671288210.
params
=
{
params
=
{
...
@@ -409,7 +418,7 @@ def test_ranker(output, client, listen_port, group):
...
@@ -409,7 +418,7 @@ def test_ranker(output, client, listen_port, group):
# have high rank correlation with scores from serial ranker.
# have high rank correlation with scores from serial ranker.
dcor
=
spearmanr
(
rnkvec_dask
,
y
).
correlation
dcor
=
spearmanr
(
rnkvec_dask
,
y
).
correlation
assert
dcor
>
0.6
assert
dcor
>
0.6
assert
spearmanr
(
rnkvec_dask
,
rnkvec_local
).
correlation
>
0.
75
assert
spearmanr
(
rnkvec_dask
,
rnkvec_local
).
correlation
>
0.
8
assert_eq
(
rnkvec_dask
,
rnkvec_dask_local
)
assert_eq
(
rnkvec_dask
,
rnkvec_dask_local
)
client
.
close
(
timeout
=
CLIENT_CLOSE_TIMEOUT
)
client
.
close
(
timeout
=
CLIENT_CLOSE_TIMEOUT
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment