Unverified Commit 2315c0d1 authored by Guillaume Lemaitre's avatar Guillaume Lemaitre Committed by GitHub
Browse files

[tests][python][sklearn] make sklearn integration test compatible with 0.24 (#3533)

* TST make sklearn integration test compatible with 0.24

* remove useless import

* remove outdated comment

* order import

* use parametrize_with_checks

* change the reason

* skip constructible if != 0.23

* make tests behave the same across sklearn version

* linter

* address suggestions
parent 6c10c4ca
...@@ -315,8 +315,16 @@ class LGBMModel(_LGBMModelBase): ...@@ -315,8 +315,16 @@ class LGBMModel(_LGBMModelBase):
self.set_params(**kwargs) self.set_params(**kwargs)
def _more_tags(self): def _more_tags(self):
return {'allow_nan': True, return {
'X_types': ['2darray', 'sparse', '1dlabels']} 'allow_nan': True,
'X_types': ['2darray', 'sparse', '1dlabels'],
'_xfail_checks': {
'check_no_attributes_set_in_init':
'scikit-learn incorrectly asserts that private attributes '
'cannot be set in __init__: '
'(see https://github.com/microsoft/LightGBM/issues/2628)'
}
}
def get_params(self, deep=True): def get_params(self, deep=True):
"""Get parameters for this estimator. """Get parameters for this estimator.
......
...@@ -4,24 +4,30 @@ import joblib ...@@ -4,24 +4,30 @@ import joblib
import math import math
import os import os
import unittest import unittest
import warnings
import lightgbm as lgb import lightgbm as lgb
import numpy as np import numpy as np
import pytest
from pkg_resources import parse_version
from sklearn import __version__ as sk_version from sklearn import __version__ as sk_version
from sklearn.base import clone from sklearn.base import clone
from sklearn.datasets import load_svmlight_file, make_multilabel_classification from sklearn.datasets import load_svmlight_file, make_multilabel_classification
from sklearn.exceptions import SkipTestWarning from sklearn.utils.estimator_checks import check_parameters_default_constructible
from sklearn.metrics import log_loss, mean_squared_error from sklearn.metrics import log_loss, mean_squared_error
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split
from sklearn.multioutput import (MultiOutputClassifier, ClassifierChain, MultiOutputRegressor, from sklearn.multioutput import (MultiOutputClassifier, ClassifierChain, MultiOutputRegressor,
RegressorChain) RegressorChain)
from sklearn.utils.estimator_checks import (_yield_all_checks, SkipTest,
check_parameters_default_constructible)
from sklearn.utils.validation import check_is_fitted from sklearn.utils.validation import check_is_fitted
from .utils import load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud from .utils import load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud
sk_version = parse_version(sk_version)
if sk_version < parse_version("0.23"):
import warnings
from sklearn.exceptions import SkipTestWarning
from sklearn.utils.estimator_checks import _yield_all_checks, SkipTest
else:
from sklearn.utils.estimator_checks import parametrize_with_checks
decreasing_generator = itertools.count(0, -1) decreasing_generator = itertools.count(0, -1)
...@@ -168,7 +174,7 @@ class TestSklearn(unittest.TestCase): ...@@ -168,7 +174,7 @@ class TestSklearn(unittest.TestCase):
self.assertLessEqual(score, 1.) self.assertLessEqual(score, 1.)
# sklearn <0.23 does not have a stacking classifier and n_features_in_ property # sklearn <0.23 does not have a stacking classifier and n_features_in_ property
@unittest.skipIf(sk_version < '0.23.0', 'scikit-learn version is less than 0.23') @unittest.skipIf(sk_version < parse_version("0.23"), 'scikit-learn version is less than 0.23')
def test_stacking_classifier(self): def test_stacking_classifier(self):
from sklearn.ensemble import StackingClassifier from sklearn.ensemble import StackingClassifier
...@@ -195,7 +201,7 @@ class TestSklearn(unittest.TestCase): ...@@ -195,7 +201,7 @@ class TestSklearn(unittest.TestCase):
self.assertTrue(all(classes)) self.assertTrue(all(classes))
# sklearn <0.23 does not have a stacking regressor and n_features_in_ property # sklearn <0.23 does not have a stacking regressor and n_features_in_ property
@unittest.skipIf(sk_version < '0.23.0', 'scikit-learn version is less than 0.23') @unittest.skipIf(sk_version < parse_version('0.23'), 'scikit-learn version is less than 0.23')
def test_stacking_regressor(self): def test_stacking_regressor(self):
from sklearn.ensemble import StackingRegressor from sklearn.ensemble import StackingRegressor
...@@ -280,7 +286,7 @@ class TestSklearn(unittest.TestCase): ...@@ -280,7 +286,7 @@ class TestSklearn(unittest.TestCase):
self.assertLessEqual(score, 1.) self.assertLessEqual(score, 1.)
# sklearn < 0.22 does not have the post fit attribute: classes_ # sklearn < 0.22 does not have the post fit attribute: classes_
@unittest.skipIf(sk_version < '0.22.0', 'scikit-learn version is less than 0.22') @unittest.skipIf(sk_version < parse_version('0.22'), 'scikit-learn version is less than 0.22')
def test_multioutput_classifier(self): def test_multioutput_classifier(self):
n_outputs = 3 n_outputs = 3
X, y = make_multilabel_classification(n_samples=100, n_features=20, X, y = make_multilabel_classification(n_samples=100, n_features=20,
...@@ -300,7 +306,7 @@ class TestSklearn(unittest.TestCase): ...@@ -300,7 +306,7 @@ class TestSklearn(unittest.TestCase):
self.assertIsInstance(classifier.booster_, lgb.Booster) self.assertIsInstance(classifier.booster_, lgb.Booster)
# sklearn < 0.23 does not have as_frame parameter # sklearn < 0.23 does not have as_frame parameter
@unittest.skipIf(sk_version < '0.23.0', 'scikit-learn version is less than 0.23') @unittest.skipIf(sk_version < parse_version('0.23'), 'scikit-learn version is less than 0.23')
def test_multioutput_regressor(self): def test_multioutput_regressor(self):
bunch = load_linnerud(as_frame=True) # returns a Bunch instance bunch = load_linnerud(as_frame=True) # returns a Bunch instance
X, y = bunch['data'], bunch['target'] X, y = bunch['data'], bunch['target']
...@@ -317,7 +323,7 @@ class TestSklearn(unittest.TestCase): ...@@ -317,7 +323,7 @@ class TestSklearn(unittest.TestCase):
self.assertIsInstance(regressor.booster_, lgb.Booster) self.assertIsInstance(regressor.booster_, lgb.Booster)
# sklearn < 0.22 does not have the post fit attribute: classes_ # sklearn < 0.22 does not have the post fit attribute: classes_
@unittest.skipIf(sk_version < '0.22.0', 'scikit-learn version is less than 0.22') @unittest.skipIf(sk_version < parse_version('0.22'), 'scikit-learn version is less than 0.22')
def test_classifier_chain(self): def test_classifier_chain(self):
n_outputs = 3 n_outputs = 3
X, y = make_multilabel_classification(n_samples=100, n_features=20, X, y = make_multilabel_classification(n_samples=100, n_features=20,
...@@ -339,7 +345,7 @@ class TestSklearn(unittest.TestCase): ...@@ -339,7 +345,7 @@ class TestSklearn(unittest.TestCase):
self.assertIsInstance(classifier.booster_, lgb.Booster) self.assertIsInstance(classifier.booster_, lgb.Booster)
# sklearn < 0.23 does not have as_frame parameter # sklearn < 0.23 does not have as_frame parameter
@unittest.skipIf(sk_version < '0.23.0', 'scikit-learn version is less than 0.23') @unittest.skipIf(sk_version < parse_version('0.23'), 'scikit-learn version is less than 0.23')
def test_regressor_chain(self): def test_regressor_chain(self):
bunch = load_linnerud(as_frame=True) # returns a Bunch instance bunch = load_linnerud(as_frame=True) # returns a Bunch instance
X, y = bunch['data'], bunch['target'] X, y = bunch['data'], bunch['target']
...@@ -452,29 +458,6 @@ class TestSklearn(unittest.TestCase): ...@@ -452,29 +458,6 @@ class TestSklearn(unittest.TestCase):
importance_gain_top1 = sorted(importances_gain, reverse=True)[0] importance_gain_top1 = sorted(importances_gain, reverse=True)[0]
self.assertNotEqual(importance_split_top1, importance_gain_top1) self.assertNotEqual(importance_split_top1, importance_gain_top1)
# sklearn <0.19 cannot accept instance, but many tests could be passed only with min_data=1 and min_data_in_bin=1
# Versions are compared as parsed versions, not strings: string comparison is
# lexicographic and would misorder releases. ``str()`` makes the comparison
# work whether ``sk_version`` is still the raw string or already parsed.
@unittest.skipIf(parse_version(str(sk_version)) < parse_version('0.19'),
                 'scikit-learn version is less than 0.19')
def test_sklearn_integration(self):
    """Run scikit-learn's estimator checks on LGBMClassifier and LGBMRegressor.

    Each estimator class is first checked for default constructibility.
    An instance built with ``min_child_samples=1, min_data_in_bin=1`` is
    then passed through every check yielded by ``_yield_all_checks``,
    except for two checks that are deliberately skipped (see inline
    comments). A check that raises ``SkipTest`` is reported as a
    ``SkipTestWarning`` instead of failing the test.
    """
    # we cannot use `check_estimator` directly since there is no skip test mechanism
    for name, estimator in ((lgb.sklearn.LGBMClassifier.__name__, lgb.sklearn.LGBMClassifier),
                            (lgb.sklearn.LGBMRegressor.__name__, lgb.sklearn.LGBMRegressor)):
        check_parameters_default_constructible(name, estimator)
        # we cannot leave default params (see https://github.com/microsoft/LightGBM/issues/833)
        estimator = estimator(min_child_samples=1, min_data_in_bin=1)
        for check in _yield_all_checks(name, estimator):
            # checks may be wrapped (object exposing ``.func``) or plain functions
            check_name = check.func.__name__ if hasattr(check, 'func') else check.__name__
            if check_name == 'check_estimators_nan_inf':
                continue  # skip test because LightGBM deals with nan
            elif check_name == "check_no_attributes_set_in_init":
                # skip test because scikit-learn incorrectly asserts that
                # private attributes cannot be set in __init__
                # (see https://github.com/microsoft/LightGBM/issues/2628)
                continue
            try:
                check(name, estimator)
            except SkipTest as message:
                # warn with the message text rather than the exception object,
                # so warning filters that match on the message receive a string
                warnings.warn(str(message), SkipTestWarning)
@unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed') @unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed')
def test_pandas_categorical(self): def test_pandas_categorical(self):
import pandas as pd import pandas as pd
...@@ -1131,7 +1114,7 @@ class TestSklearn(unittest.TestCase): ...@@ -1131,7 +1114,7 @@ class TestSklearn(unittest.TestCase):
init_gbm.evals_result_['valid_0']['multi_logloss'][-1]) init_gbm.evals_result_['valid_0']['multi_logloss'][-1])
# sklearn < 0.22 requires passing "attributes" argument # sklearn < 0.22 requires passing "attributes" argument
@unittest.skipIf(sk_version < '0.22.0', 'scikit-learn version is less than 0.22') @unittest.skipIf(sk_version < parse_version('0.22'), 'scikit-learn version is less than 0.22')
def test_check_is_fitted(self): def test_check_is_fitted(self):
X, y = load_digits(n_class=2, return_X_y=True) X, y = load_digits(n_class=2, return_X_y=True)
est = lgb.LGBMModel(n_estimators=5, objective="binary") est = lgb.LGBMModel(n_estimators=5, objective="binary")
...@@ -1149,3 +1132,49 @@ class TestSklearn(unittest.TestCase): ...@@ -1149,3 +1132,49 @@ class TestSklearn(unittest.TestCase):
rnk.fit(X, y, group=np.ones(X.shape[0])) rnk.fit(X, y, group=np.ones(X.shape[0]))
for model in models: for model in models:
check_is_fitted(model) check_is_fitted(model)
def _tested_estimators():
    """Yield a default-constructed instance of each LightGBM estimator under test.

    Currently yields one ``LGBMClassifier`` and one ``LGBMRegressor``, in that
    order. Instances are created lazily, one per iteration step.
    """
    estimator_classes = (lgb.sklearn.LGBMClassifier, lgb.sklearn.LGBMRegressor)
    yield from (estimator_class() for estimator_class in estimator_classes)
if sk_version < parse_version("0.23"):
    # Older scikit-learn: build the (estimator, check) pairs by hand from
    # ``_yield_all_checks`` and call each check with an explicit name.

    def _generate_checks_per_estimator(check_generator, estimators):
        """Yield an ``(estimator, check)`` pair for every check of every estimator.

        ``check_generator`` is called as ``check_generator(name, estimator)``
        where ``name`` is the estimator's class name.
        """
        for est in estimators:
            for single_check in check_generator(est.__class__.__name__, est):
                yield est, single_check

    @pytest.mark.skipif(
        sk_version < parse_version("0.21"),
        reason="scikit-learn version is less than 0.21",
    )
    @pytest.mark.parametrize(
        "estimator, check",
        _generate_checks_per_estimator(_yield_all_checks, _tested_estimators()),
    )
    def test_sklearn_integration(estimator, check):
        """Run one scikit-learn estimator check against one LightGBM estimator.

        Checks listed in the estimator's ``_xfail_checks`` tag are skipped: the
        recorded reason is emitted as a ``SkipTestWarning`` and ``SkipTest`` is
        raised.
        """
        xfail_checks = estimator._get_tags()["_xfail_checks"]
        # checks are either plain functions or wrappers exposing ``.func``
        if hasattr(check, "__name__"):
            check_name = check.__name__
        else:
            check_name = check.func.__name__
        if xfail_checks and check_name in xfail_checks:
            warnings.warn(xfail_checks[check_name], SkipTestWarning)
            raise SkipTest
        # default params cannot be used
        # (see https://github.com/microsoft/LightGBM/issues/833)
        estimator.set_params(min_child_samples=1, min_data_in_bin=1)
        check(estimator.__class__.__name__, estimator)

else:
    # ``parametrize_with_checks`` supplies the (estimator, check) pairs itself;
    # here each check is called with the estimator only.

    @parametrize_with_checks(list(_tested_estimators()))
    def test_sklearn_integration(estimator, check, request):
        """Run one scikit-learn estimator check against one LightGBM estimator.

        ``request`` is accepted but not used in the body.
        """
        # default params cannot be used
        # (see https://github.com/microsoft/LightGBM/issues/833)
        estimator.set_params(min_child_samples=1, min_data_in_bin=1)
        check(estimator)
@pytest.mark.skipif(
    sk_version >= parse_version("0.24"),
    reason="Default constructible check included in common check from 0.24",
)
@pytest.mark.parametrize("estimator", list(_tested_estimators()))
def test_parameters_default_constructible(estimator):
    """Check that the estimator's class can be constructed with default parameters.

    Skipped for scikit-learn >= 0.24, where (per the skip reason) this check
    is already part of the common checks.
    """
    estimator_class = estimator.__class__
    # The check takes the class itself (not the instance) together with its name.
    check_parameters_default_constructible(estimator_class.__name__, estimator_class)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment