Unverified commit eb5f471b, authored by imjwang, committed by GitHub
Browse files

[tests][dask] add scikit-learn compatibility tests (fixes #3894) (#3947)



* add test_dask.py

* Update tests/python_package_test/test_dask.py
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* clients

* remove ports

* safe sklearn checks

* safe sklearn checks

* fix whitespace

* fix whitespace-try 2

* fix whitespace-try 3

* isort

* isort

* sklearn_checks_to_learn
Co-authored-by: James Lamb <jaylamb20@gmail.com>
parent a3f4831d
......@@ -24,6 +24,7 @@ import dask.dataframe as dd
import joblib
import numpy as np
import pandas as pd
import sklearn.utils.estimator_checks as sklearn_checks
from dask.array.utils import assert_eq
from dask.distributed import Client, LocalCluster, default_client, wait
from distributed.utils_test import client, cluster_fixture, gen_cluster, loop
......@@ -1081,3 +1082,36 @@ def test_dask_methods_and_sklearn_equivalents_have_similar_signatures(methods):
for param in dask_spec.args:
error_msg = f"param '{param}' has different default values in the methods"
assert dask_params[param].default == sklearn_params[param].default, error_msg
def sklearn_checks_to_run():
    """Yield the scikit-learn estimator checks that the Dask estimators are run through.

    Each check is looked up by name on ``sklearn.utils.estimator_checks``.
    A name that is missing from that module (or resolves to a falsy value)
    is skipped silently instead of raising ``AttributeError``.

    Yields:
        The check callables that were found, in the order listed below.
    """
    wanted = (
        "check_estimator_get_tags_default_keys",
        "check_get_params_invariance",
        "check_set_params",
    )
    # ``None`` is the fallback for names the module does not define.
    looked_up = (getattr(sklearn_checks, name, None) for name in wanted)
    yield from (func for func in looked_up if func)
def _tested_estimators():
    """Yield one default-constructed instance of each Dask estimator under test."""
    estimator_classes = (lgb.DaskLGBMClassifier, lgb.DaskLGBMRegressor)
    yield from (estimator_cls() for estimator_cls in estimator_classes)
# The generators are materialised with list() so the parametrize arguments are
# concrete sequences, matching test_parameters_default_constructible.
@pytest.mark.parametrize("estimator", list(_tested_estimators()))
@pytest.mark.parametrize("check", list(sklearn_checks_to_run()))
def test_sklearn_integration(estimator, check, client):
    """Run one scikit-learn estimator check against one Dask estimator.

    Args:
        estimator: An instance produced by ``_tested_estimators()``.
        check: A check callable produced by ``sklearn_checks_to_run()``;
            it is invoked as ``check(name, estimator)``.
        client: The Dask ``client`` test fixture; it is closed at the end of
            the test.
    """
    try:
        estimator.set_params(local_listen_port=18000, time_out=5)
        name = type(estimator).__name__
        check(name, estimator)
    finally:
        # Close the client even when the check raises, so a failing check
        # does not leave the client open.
        client.close(timeout=CLIENT_CLOSE_TIMEOUT)
# This test is kept apart from test_sklearn_integration because the sklearn
# check it runs takes the estimator class (not yet constructed), not an instance.
@pytest.mark.parametrize("estimator", list(_tested_estimators()))
def test_parameters_default_constructible(estimator):
    """Run sklearn's ``check_parameters_default_constructible`` on the estimator's class."""
    Estimator = estimator.__class__
    name = Estimator.__name__
    sklearn_checks.check_parameters_default_constructible(name, Estimator)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment