Commit f53116af, authored by Nikita Titov and committed by Guolin Ke
Browse files

[ci][python] fixes according to scikit-learn 0.20 release (#1707)

* fixed FutureWarning about cv default value

* fixed according to new check_estimator API

* fixed joblib warning
parent 7825084f
...@@ -26,10 +26,11 @@ install: ...@@ -26,10 +26,11 @@ install:
default {$env:MINICONDA = """C:\Miniconda36-x64"""} default {$env:MINICONDA = """C:\Miniconda36-x64"""}
} }
- set PATH=%MINICONDA%;%MINICONDA%\Scripts;%PATH% - set PATH=%MINICONDA%;%MINICONDA%\Scripts;%PATH%
- set SKLEARN_SITE_JOBLIB=true # temp fix for joblib warning in examples
- ps: $env:LGB_VER = (Get-Content VERSION.txt).trim() - ps: $env:LGB_VER = (Get-Content VERSION.txt).trim()
- conda config --set always_yes yes --set changeps1 no - conda config --set always_yes yes --set changeps1 no
- conda update -q conda - conda update -q conda
- conda create -q -n test-env python=%PYTHON_VERSION% numpy nose scipy scikit-learn pandas matplotlib python-graphviz pytest - conda create -q -n test-env python=%PYTHON_VERSION% numpy nose scipy scikit-learn pandas matplotlib python-graphviz pytest joblib
- activate test-env - activate test-env
build_script: build_script:
......
...@@ -66,7 +66,7 @@ param_grid = { ...@@ -66,7 +66,7 @@ param_grid = {
'n_estimators': [20, 40] 'n_estimators': [20, 40]
} }
gbm = GridSearchCV(estimator, param_grid) gbm = GridSearchCV(estimator, param_grid, cv=3)
gbm.fit(X_train, y_train) gbm.fit(X_train, y_train)
......
...@@ -90,7 +90,8 @@ try: ...@@ -90,7 +90,8 @@ try:
from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_sample_weight from sklearn.utils.class_weight import compute_sample_weight
from sklearn.utils.multiclass import check_classification_targets from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_X_y, check_array, check_consistent_length from sklearn.utils.validation import (assert_all_finite, check_X_y,
check_array, check_consistent_length)
try: try:
from sklearn.model_selection import StratifiedKFold, GroupKFold from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.exceptions import NotFittedError from sklearn.exceptions import NotFittedError
...@@ -108,6 +109,7 @@ try: ...@@ -108,6 +109,7 @@ try:
_LGBMCheckXY = check_X_y _LGBMCheckXY = check_X_y
_LGBMCheckArray = check_array _LGBMCheckArray = check_array
_LGBMCheckConsistentLength = check_consistent_length _LGBMCheckConsistentLength = check_consistent_length
_LGBMAssertAllFinite = assert_all_finite
_LGBMCheckClassificationTargets = check_classification_targets _LGBMCheckClassificationTargets = check_classification_targets
_LGBMComputeSampleWeight = compute_sample_weight _LGBMComputeSampleWeight = compute_sample_weight
except ImportError: except ImportError:
...@@ -122,6 +124,7 @@ except ImportError: ...@@ -122,6 +124,7 @@ except ImportError:
_LGBMCheckXY = None _LGBMCheckXY = None
_LGBMCheckArray = None _LGBMCheckArray = None
_LGBMCheckConsistentLength = None _LGBMCheckConsistentLength = None
_LGBMAssertAllFinite = None
_LGBMCheckClassificationTargets = None _LGBMCheckClassificationTargets = None
_LGBMComputeSampleWeight = None _LGBMComputeSampleWeight = None
......
...@@ -10,7 +10,7 @@ from .basic import Dataset, LightGBMError ...@@ -10,7 +10,7 @@ from .basic import Dataset, LightGBMError
from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase, from .compat import (SKLEARN_INSTALLED, _LGBMClassifierBase,
LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase, LGBMNotFittedError, _LGBMLabelEncoder, _LGBMModelBase,
_LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength, _LGBMRegressorBase, _LGBMCheckXY, _LGBMCheckArray, _LGBMCheckConsistentLength,
_LGBMCheckClassificationTargets, _LGBMComputeSampleWeight, _LGBMAssertAllFinite, _LGBMCheckClassificationTargets, _LGBMComputeSampleWeight,
argc_, range_, string_type, DataFrame, LGBMDeprecationWarning) argc_, range_, string_type, DataFrame, LGBMDeprecationWarning)
from .engine import train from .engine import train
...@@ -656,6 +656,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -656,6 +656,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
eval_class_weight=None, eval_init_score=None, eval_metric=None, eval_class_weight=None, eval_init_score=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto', callbacks=None): feature_name='auto', categorical_feature='auto', callbacks=None):
_LGBMAssertAllFinite(y)
_LGBMCheckClassificationTargets(y) _LGBMCheckClassificationTargets(y)
self._le = _LGBMLabelEncoder().fit(y) self._le = _LGBMLabelEncoder().fit(y)
_y = self._le.transform(y) _y = self._le.transform(y)
......
...@@ -6,19 +6,16 @@ import unittest ...@@ -6,19 +6,16 @@ import unittest
import lightgbm as lgb import lightgbm as lgb
import numpy as np import numpy as np
from sklearn import __version__ as sk_version
from sklearn.base import clone from sklearn.base import clone
from sklearn.datasets import (load_boston, load_breast_cancer, load_digits, from sklearn.datasets import (load_boston, load_breast_cancer, load_digits,
load_iris, load_svmlight_file) load_iris, load_svmlight_file)
from sklearn.exceptions import SkipTestWarning
from sklearn.externals import joblib from sklearn.externals import joblib
from sklearn.metrics import log_loss, mean_squared_error from sklearn.metrics import log_loss, mean_squared_error
from sklearn.model_selection import GridSearchCV, train_test_split from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.utils.estimator_checks import (_yield_all_checks, SkipTest, from sklearn.utils.estimator_checks import (_yield_all_checks, SkipTest,
check_parameters_default_constructible) check_parameters_default_constructible)
try:
from sklearn.utils.estimator_checks import check_no_fit_attributes_set_in_init
sklearn_at_least_019 = True
except ImportError:
sklearn_at_least_019 = False
def multi_error(y_true, y_pred): def multi_error(y_true, y_pred):
...@@ -180,17 +177,17 @@ class TestSklearn(unittest.TestCase): ...@@ -180,17 +177,17 @@ class TestSklearn(unittest.TestCase):
self.assertNotEqual(importance_split_top1, importance_gain_top1) self.assertNotEqual(importance_split_top1, importance_gain_top1)
# sklearn <0.19 cannot accept instance, but many tests could be passed only with min_data=1 and min_data_in_bin=1 # sklearn <0.19 cannot accept instance, but many tests could be passed only with min_data=1 and min_data_in_bin=1
@unittest.skipIf(not sklearn_at_least_019, 'scikit-learn version is less than 0.19') @unittest.skipIf(sk_version < '0.19.0', 'scikit-learn version is less than 0.19')
def test_sklearn_integration(self): def test_sklearn_integration(self):
# we cannot use `check_estimator` directly since there is no skip test mechanism # we cannot use `check_estimator` directly since there is no skip test mechanism
for name, estimator in ((lgb.sklearn.LGBMClassifier.__name__, lgb.sklearn.LGBMClassifier), for name, estimator in ((lgb.sklearn.LGBMClassifier.__name__, lgb.sklearn.LGBMClassifier),
(lgb.sklearn.LGBMRegressor.__name__, lgb.sklearn.LGBMRegressor)): (lgb.sklearn.LGBMRegressor.__name__, lgb.sklearn.LGBMRegressor)):
check_parameters_default_constructible(name, estimator) check_parameters_default_constructible(name, estimator)
check_no_fit_attributes_set_in_init(name, estimator)
# we cannot leave default params (see https://github.com/Microsoft/LightGBM/issues/833) # we cannot leave default params (see https://github.com/Microsoft/LightGBM/issues/833)
estimator = estimator(min_child_samples=1, min_data_in_bin=1) estimator = estimator(min_child_samples=1, min_data_in_bin=1)
for check in _yield_all_checks(name, estimator): for check in _yield_all_checks(name, estimator):
if check.__name__ == 'check_estimators_nan_inf': check_name = check.func.__name__ if hasattr(check, 'func') else check.__name__
if check_name == 'check_estimators_nan_inf':
continue # skip test because LightGBM deals with nan continue # skip test because LightGBM deals with nan
try: try:
check(name, estimator) check(name, estimator)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment