Unverified Commit 2a00b6ff authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[dask] [ci] add support for scikit-learn 0.24+ in tests (fixes #4031) (#4032)



* [dask] [ci] add support for scikit-learn 0.24+ in tests (fixes #4031)

* Update tests/python_package_test/test_dask.py
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* try upgrading mixtexsetup

* they changed the executable name UGH

* more changes for executable name

* another path change

* changing package mirrors

* undo experiments
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 6356e659
...@@ -28,12 +28,16 @@ import sklearn.utils.estimator_checks as sklearn_checks ...@@ -28,12 +28,16 @@ import sklearn.utils.estimator_checks as sklearn_checks
from dask.array.utils import assert_eq from dask.array.utils import assert_eq
from dask.distributed import Client, LocalCluster, default_client, wait from dask.distributed import Client, LocalCluster, default_client, wait
from distributed.utils_test import client, cluster_fixture, gen_cluster, loop from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
from pkg_resources import parse_version
from scipy.sparse import csr_matrix from scipy.sparse import csr_matrix
from scipy.stats import spearmanr from scipy.stats import spearmanr
from sklearn import __version__ as sk_version
from sklearn.datasets import make_blobs, make_regression from sklearn.datasets import make_blobs, make_regression
from .utils import make_ranking from .utils import make_ranking
sk_version = parse_version(sk_version)
# time, in seconds, to wait for the Dask client to close. Used to avoid teardown errors # time, in seconds, to wait for the Dask client to close. Used to avoid teardown errors
# see https://distributed.dask.org/en/latest/api.html#distributed.Client.close # see https://distributed.dask.org/en/latest/api.html#distributed.Client.close
CLIENT_CLOSE_TIMEOUT = 120 CLIENT_CLOSE_TIMEOUT = 120
...@@ -1253,5 +1257,9 @@ def test_sklearn_integration(estimator, check, client): ...@@ -1253,5 +1257,9 @@ def test_sklearn_integration(estimator, check, client):
# this test is separate because it takes a not-yet-constructed estimator # this test is separate because it takes a not-yet-constructed estimator
@pytest.mark.parametrize("estimator", list(_tested_estimators())) @pytest.mark.parametrize("estimator", list(_tested_estimators()))
def test_parameters_default_constructible(estimator): def test_parameters_default_constructible(estimator):
name, Estimator = estimator.__class__.__name__, estimator.__class__ name = estimator.__class__.__name__
if sk_version >= parse_version("0.24"):
Estimator = estimator
else:
Estimator = estimator.__class__
sklearn_checks.check_parameters_default_constructible(name, Estimator) sklearn_checks.check_parameters_default_constructible(name, Estimator)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment