"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "302f84bcc0f160b37230569f070ff18715f02baa"
Unverified commit eb5f471b, authored by imjwang, committed by GitHub
Browse files

[tests][dask] add scikit-learn compatibility tests (fixes #3894) (#3947)



* add test_dask.py

* Update tests/python_package_test/test_dask.py
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* clients

* remove ports

* safe sklearn checks

* safe sklearn checks

* fix whitespace

* fix whitespace-try 2

* fix whitespace-try 3

* isort

* isort

* sklearn_checks_to_learn
Co-authored-by: James Lamb <jaylamb20@gmail.com>
parent a3f4831d
...@@ -24,6 +24,7 @@ import dask.dataframe as dd ...@@ -24,6 +24,7 @@ import dask.dataframe as dd
import joblib import joblib
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import sklearn.utils.estimator_checks as sklearn_checks
from dask.array.utils import assert_eq from dask.array.utils import assert_eq
from dask.distributed import Client, LocalCluster, default_client, wait from dask.distributed import Client, LocalCluster, default_client, wait
from distributed.utils_test import client, cluster_fixture, gen_cluster, loop from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
...@@ -1081,3 +1082,36 @@ def test_dask_methods_and_sklearn_equivalents_have_similar_signatures(methods): ...@@ -1081,3 +1082,36 @@ def test_dask_methods_and_sklearn_equivalents_have_similar_signatures(methods):
for param in dask_spec.args: for param in dask_spec.args:
error_msg = f"param '{param}' has different default values in the methods" error_msg = f"param '{param}' has different default values in the methods"
assert dask_params[param].default == sklearn_params[param].default, error_msg assert dask_params[param].default == sklearn_params[param].default, error_msg
def sklearn_checks_to_run():
    """Yield the scikit-learn estimator checks exercised by the Dask tests.

    Each check is looked up by name on ``sklearn_checks``; a name that is
    not an attribute of that module is silently skipped.
    """
    wanted = (
        "check_estimator_get_tags_default_keys",
        "check_get_params_invariance",
        "check_set_params",
    )
    # ``None`` is the lookup fallback, so filtering on truthiness drops
    # every check that could not be found.
    found = (getattr(sklearn_checks, name, None) for name in wanted)
    yield from filter(None, found)
def _tested_estimators():
    """Yield a default-constructed instance of each Dask estimator under test."""
    estimator_classes = [lgb.DaskLGBMClassifier, lgb.DaskLGBMRegressor]
    yield from (estimator_cls() for estimator_cls in estimator_classes)
@pytest.mark.parametrize("estimator", _tested_estimators())
@pytest.mark.parametrize("check", sklearn_checks_to_run())
def test_sklearn_integration(estimator, check, client):
    """Run one scikit-learn estimator check against one Dask estimator.

    Parameters
    ----------
    estimator : object
        Estimator instance produced by ``_tested_estimators()``.
    check : callable
        Check produced by ``sklearn_checks_to_run()``; called as
        ``check(name, estimator)``.
    client : object
        Client fixture; it is closed explicitly before the test returns.
    """
    try:
        estimator.set_params(local_listen_port=18000, time_out=5)
        name = type(estimator).__name__
        check(name, estimator)
    finally:
        # Close in ``finally`` so a failing check does not skip the explicit,
        # timeout-bounded shutdown of the client.
        client.close(timeout=CLIENT_CLOSE_TIMEOUT)
# Kept apart from the other scikit-learn checks: this check is handed the
# estimator class itself rather than an already-constructed instance.
@pytest.mark.parametrize("estimator", list(_tested_estimators()))
def test_parameters_default_constructible(estimator):
    """Check that the estimator's class passes scikit-learn's default-constructibility check."""
    estimator_cls = estimator.__class__
    sklearn_checks.check_parameters_default_constructible(
        estimator_cls.__name__,
        estimator_cls,
    )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment