Unverified commit 2315c0d1, authored by Guillaume Lemaitre, committed by GitHub
Browse files

[tests][python][sklearn] make sklearn integration test compatible with 0.24 (#3533)

* TST make sklearn integration test compatible with 0.24

* remove useless import

* remove outdated comment

* order import

* use parametrize_with_checks

* change the reason

* skip constructible if != 0.23

* make tests behave the same across sklearn version

* linter

* address suggestions
parent 6c10c4ca
......@@ -315,8 +315,16 @@ class LGBMModel(_LGBMModelBase):
self.set_params(**kwargs)
def _more_tags(self):
return {'allow_nan': True,
'X_types': ['2darray', 'sparse', '1dlabels']}
return {
'allow_nan': True,
'X_types': ['2darray', 'sparse', '1dlabels'],
'_xfail_checks': {
'check_no_attributes_set_in_init':
'scikit-learn incorrectly asserts that private attributes '
'cannot be set in __init__: '
'(see https://github.com/microsoft/LightGBM/issues/2628)'
}
}
def get_params(self, deep=True):
"""Get parameters for this estimator.
......
......@@ -4,24 +4,30 @@ import joblib
import math
import os
import unittest
import warnings
import lightgbm as lgb
import numpy as np
import pytest
from pkg_resources import parse_version
from sklearn import __version__ as sk_version
from sklearn.base import clone
from sklearn.datasets import load_svmlight_file, make_multilabel_classification
from sklearn.exceptions import SkipTestWarning
from sklearn.utils.estimator_checks import check_parameters_default_constructible
from sklearn.metrics import log_loss, mean_squared_error
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split
from sklearn.multioutput import (MultiOutputClassifier, ClassifierChain, MultiOutputRegressor,
RegressorChain)
from sklearn.utils.estimator_checks import (_yield_all_checks, SkipTest,
check_parameters_default_constructible)
from sklearn.utils.validation import check_is_fitted
from .utils import load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud
sk_version = parse_version(sk_version)
if sk_version < parse_version("0.23"):
import warnings
from sklearn.exceptions import SkipTestWarning
from sklearn.utils.estimator_checks import _yield_all_checks, SkipTest
else:
from sklearn.utils.estimator_checks import parametrize_with_checks
decreasing_generator = itertools.count(0, -1)
......@@ -168,7 +174,7 @@ class TestSklearn(unittest.TestCase):
self.assertLessEqual(score, 1.)
# sklearn <0.23 does not have a stacking classifier and n_features_in_ property
@unittest.skipIf(sk_version < '0.23.0', 'scikit-learn version is less than 0.23')
@unittest.skipIf(sk_version < parse_version("0.23"), 'scikit-learn version is less than 0.23')
def test_stacking_classifier(self):
from sklearn.ensemble import StackingClassifier
......@@ -195,7 +201,7 @@ class TestSklearn(unittest.TestCase):
self.assertTrue(all(classes))
# sklearn <0.23 does not have a stacking regressor and n_features_in_ property
@unittest.skipIf(sk_version < '0.23.0', 'scikit-learn version is less than 0.23')
@unittest.skipIf(sk_version < parse_version('0.23'), 'scikit-learn version is less than 0.23')
def test_stacking_regressor(self):
from sklearn.ensemble import StackingRegressor
......@@ -280,7 +286,7 @@ class TestSklearn(unittest.TestCase):
self.assertLessEqual(score, 1.)
# sklearn < 0.22 does not have the post fit attribute: classes_
@unittest.skipIf(sk_version < '0.22.0', 'scikit-learn version is less than 0.22')
@unittest.skipIf(sk_version < parse_version('0.22'), 'scikit-learn version is less than 0.22')
def test_multioutput_classifier(self):
n_outputs = 3
X, y = make_multilabel_classification(n_samples=100, n_features=20,
......@@ -300,7 +306,7 @@ class TestSklearn(unittest.TestCase):
self.assertIsInstance(classifier.booster_, lgb.Booster)
# sklearn < 0.23 does not have as_frame parameter
@unittest.skipIf(sk_version < '0.23.0', 'scikit-learn version is less than 0.23')
@unittest.skipIf(sk_version < parse_version('0.23'), 'scikit-learn version is less than 0.23')
def test_multioutput_regressor(self):
bunch = load_linnerud(as_frame=True) # returns a Bunch instance
X, y = bunch['data'], bunch['target']
......@@ -317,7 +323,7 @@ class TestSklearn(unittest.TestCase):
self.assertIsInstance(regressor.booster_, lgb.Booster)
# sklearn < 0.22 does not have the post fit attribute: classes_
@unittest.skipIf(sk_version < '0.22.0', 'scikit-learn version is less than 0.22')
@unittest.skipIf(sk_version < parse_version('0.22'), 'scikit-learn version is less than 0.22')
def test_classifier_chain(self):
n_outputs = 3
X, y = make_multilabel_classification(n_samples=100, n_features=20,
......@@ -339,7 +345,7 @@ class TestSklearn(unittest.TestCase):
self.assertIsInstance(classifier.booster_, lgb.Booster)
# sklearn < 0.23 does not have as_frame parameter
@unittest.skipIf(sk_version < '0.23.0', 'scikit-learn version is less than 0.23')
@unittest.skipIf(sk_version < parse_version('0.23'), 'scikit-learn version is less than 0.23')
def test_regressor_chain(self):
bunch = load_linnerud(as_frame=True) # returns a Bunch instance
X, y = bunch['data'], bunch['target']
......@@ -452,29 +458,6 @@ class TestSklearn(unittest.TestCase):
importance_gain_top1 = sorted(importances_gain, reverse=True)[0]
self.assertNotEqual(importance_split_top1, importance_gain_top1)
# sklearn <0.19 cannot accept instance, but many tests could be passed only with min_data=1 and min_data_in_bin=1
# NOTE: `sk_version` is a parsed version object, so it must be compared with another
# parsed version (comparing it with a plain string raises TypeError).
@unittest.skipIf(sk_version < parse_version('0.19'), 'scikit-learn version is less than 0.19')
def test_sklearn_integration(self):
    """Run scikit-learn's estimator checks on LGBMClassifier and LGBMRegressor.

    Checks that raise ``SkipTest`` are reported as ``SkipTestWarning`` instead of
    failing; the checks listed in ``skipped_checks`` are not run at all.
    """
    # we cannot use `check_estimator` directly since there is no skip test mechanism
    skipped_checks = {
        # skip test because LightGBM deals with nan
        'check_estimators_nan_inf',
        # skip test because scikit-learn incorrectly asserts that
        # private attributes cannot be set in __init__
        # (see https://github.com/microsoft/LightGBM/issues/2628)
        'check_no_attributes_set_in_init',
    }
    for estimator_class in (lgb.sklearn.LGBMClassifier, lgb.sklearn.LGBMRegressor):
        name = estimator_class.__name__
        check_parameters_default_constructible(name, estimator_class)
        # we cannot leave default params (see https://github.com/microsoft/LightGBM/issues/833)
        estimator = estimator_class(min_child_samples=1, min_data_in_bin=1)
        for check in _yield_all_checks(name, estimator):
            # partial objects expose the wrapped function through `.func`
            check_name = check.func.__name__ if hasattr(check, 'func') else check.__name__
            if check_name in skipped_checks:
                continue
            try:
                check(name, estimator)
            except SkipTest as message:
                warnings.warn(message, SkipTestWarning)
@unittest.skipIf(not lgb.compat.PANDAS_INSTALLED, 'pandas is not installed')
def test_pandas_categorical(self):
import pandas as pd
......@@ -1131,7 +1114,7 @@ class TestSklearn(unittest.TestCase):
init_gbm.evals_result_['valid_0']['multi_logloss'][-1])
# sklearn < 0.22 requires passing "attributes" argument
@unittest.skipIf(sk_version < '0.22.0', 'scikit-learn version is less than 0.22')
@unittest.skipIf(sk_version < parse_version('0.22'), 'scikit-learn version is less than 0.22')
def test_check_is_fitted(self):
X, y = load_digits(n_class=2, return_X_y=True)
est = lgb.LGBMModel(n_estimators=5, objective="binary")
......@@ -1149,3 +1132,49 @@ class TestSklearn(unittest.TestCase):
rnk.fit(X, y, group=np.ones(X.shape[0]))
for model in models:
check_is_fitted(model)
def _tested_estimators():
    """Yield one default-constructed instance of each estimator class under test."""
    estimator_classes = (lgb.sklearn.LGBMClassifier, lgb.sklearn.LGBMRegressor)
    yield from (estimator_cls() for estimator_cls in estimator_classes)
if sk_version < parse_version("0.23"):
    def _generate_checks_per_estimator(check_generator, estimators):
        """Yield ``(estimator, check)`` pairs for every check generated per estimator.

        ``check_generator`` is called as ``check_generator(name, estimator)`` where
        ``name`` is the estimator's class name.
        """
        for est in estimators:
            est_name = est.__class__.__name__
            yield from ((est, single_check) for single_check in check_generator(est_name, est))

    @pytest.mark.skipif(
        sk_version < parse_version("0.21"), reason="scikit-learn version is less than 0.21"
    )
    @pytest.mark.parametrize(
        "estimator, check",
        _generate_checks_per_estimator(_yield_all_checks, _tested_estimators()),
    )
    def test_sklearn_integration(estimator, check):
        """Run one scikit-learn check, skipping those listed in the ``_xfail_checks`` tag."""
        expected_failures = estimator._get_tags()["_xfail_checks"]
        # partial objects have no ``__name__``; fall back to the wrapped function
        try:
            check_name = check.__name__
        except AttributeError:
            check_name = check.func.__name__
        if expected_failures and check_name in expected_failures:
            warnings.warn(expected_failures[check_name], SkipTestWarning)
            raise SkipTest
        estimator.set_params(min_child_samples=1, min_data_in_bin=1)
        check(estimator.__class__.__name__, estimator)
else:
    @parametrize_with_checks([*_tested_estimators()])
    def test_sklearn_integration(estimator, check, request):
        """Run one scikit-learn check generated by ``parametrize_with_checks``."""
        estimator.set_params(min_child_samples=1, min_data_in_bin=1)
        check(estimator)
@pytest.mark.skipif(
    sk_version >= parse_version("0.24"),
    reason="Default constructible check included in common check from 0.24"
)
@pytest.mark.parametrize("estimator", [*_tested_estimators()])
def test_parameters_default_constructible(estimator):
    """Check that the estimator's class is default-constructible."""
    estimator_cls = estimator.__class__
    # Test that estimators are default-constructible
    check_parameters_default_constructible(estimator_cls.__name__, estimator_cls)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment