Commit 015c8fff authored by Nikita Titov, committed by Guolin Ke
Browse files

[python] improved sklearn interface (#870)

* improved sklearn interface; added sklearn's tests

* moved best_score into the if statement

* improved docstrings; simplified LGBMCheckConsistentLength

* fixed typo

* pylint

* updated example

* fixed Ranker interface

* added missed boosting_type

* made autocomplete more comfortable by hiding unused objects

* removed check for None of eval_at

* fixed according to review

* fixed typo

* added description of fit return type

* dictionary->dict for short

* markdown cleanup
parent 898c88d1
...@@ -40,7 +40,7 @@ if [[ ${TASK} == "if-else" ]]; then ...@@ -40,7 +40,7 @@ if [[ ${TASK} == "if-else" ]]; then
exit 0 exit 0
fi fi
conda install --yes numpy scipy scikit-learn pandas matplotlib conda install --yes numpy nose scipy scikit-learn pandas matplotlib
pip install pytest pip install pytest
if [[ ${TASK} == "sdist" ]]; then if [[ ${TASK} == "sdist" ]]; then
......
...@@ -21,7 +21,7 @@ test_script: ...@@ -21,7 +21,7 @@ test_script:
- conda config --set always_yes yes --set changeps1 no - conda config --set always_yes yes --set changeps1 no
- conda update -q conda - conda update -q conda
- conda info -a - conda info -a
- conda install --yes numpy scipy scikit-learn pandas matplotlib - conda install --yes numpy nose scipy scikit-learn pandas matplotlib
- pip install pep8 pytest - pip install pep8 pytest
- pytest tests/c_api_test/test_.py - pytest tests/c_api_test/test_.py
- "set /p LGB_VER=< VERSION.txt" - "set /p LGB_VER=< VERSION.txt"
......
...@@ -28,7 +28,7 @@ gbm.fit(X_train, y_train, ...@@ -28,7 +28,7 @@ gbm.fit(X_train, y_train,
print('Start predicting...') print('Start predicting...')
# predict # predict
y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration) y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration_)
# eval # eval
print('The rmse of prediction is:', mean_squared_error(y_test, y_pred) ** 0.5) print('The rmse of prediction is:', mean_squared_error(y_test, y_pred) ** 0.5)
......
...@@ -64,23 +64,38 @@ try: ...@@ -64,23 +64,38 @@ try:
from sklearn.base import RegressorMixin, ClassifierMixin from sklearn.base import RegressorMixin, ClassifierMixin
from sklearn.preprocessing import LabelEncoder from sklearn.preprocessing import LabelEncoder
from sklearn.utils import deprecated from sklearn.utils import deprecated
from sklearn.utils.multiclass import check_classification_targets
from sklearn.utils.validation import check_X_y, check_array, check_consistent_length
try: try:
from sklearn.model_selection import StratifiedKFold, GroupKFold from sklearn.model_selection import StratifiedKFold, GroupKFold
from sklearn.exceptions import NotFittedError
except ImportError: except ImportError:
from sklearn.cross_validation import StratifiedKFold, GroupKFold from sklearn.cross_validation import StratifiedKFold, GroupKFold
from sklearn.utils.validation import NotFittedError
SKLEARN_INSTALLED = True SKLEARN_INSTALLED = True
LGBMModelBase = BaseEstimator _LGBMModelBase = BaseEstimator
LGBMRegressorBase = RegressorMixin _LGBMRegressorBase = RegressorMixin
LGBMClassifierBase = ClassifierMixin _LGBMClassifierBase = ClassifierMixin
LGBMLabelEncoder = LabelEncoder _LGBMLabelEncoder = LabelEncoder
LGBMDeprecated = deprecated LGBMDeprecated = deprecated
LGBMStratifiedKFold = StratifiedKFold LGBMNotFittedError = NotFittedError
LGBMGroupKFold = GroupKFold _LGBMStratifiedKFold = StratifiedKFold
_LGBMGroupKFold = GroupKFold
_LGBMCheckXY = check_X_y
_LGBMCheckArray = check_array
_LGBMCheckConsistentLength = check_consistent_length
_LGBMCheckClassificationTargets = check_classification_targets
except ImportError: except ImportError:
SKLEARN_INSTALLED = False SKLEARN_INSTALLED = False
LGBMModelBase = object _LGBMModelBase = object
LGBMClassifierBase = object _LGBMClassifierBase = object
LGBMRegressorBase = object _LGBMRegressorBase = object
LGBMLabelEncoder = None _LGBMLabelEncoder = None
LGBMStratifiedKFold = None # LGBMDeprecated = None Don't uncomment it because it causes error without installed sklearn
LGBMGroupKFold = None LGBMNotFittedError = ValueError
_LGBMStratifiedKFold = None
_LGBMGroupKFold = None
_LGBMCheckXY = None
_LGBMCheckArray = None
_LGBMCheckConsistentLength = None
_LGBMCheckClassificationTargets = None
...@@ -11,7 +11,7 @@ import numpy as np ...@@ -11,7 +11,7 @@ import numpy as np
from . import callback from . import callback
from .basic import Booster, Dataset, LightGBMError, _InnerPredictor from .basic import Booster, Dataset, LightGBMError, _InnerPredictor
from .compat import (SKLEARN_INSTALLED, LGBMGroupKFold, LGBMStratifiedKFold, from .compat import (SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold,
integer_types, range_, string_type) integer_types, range_, string_type)
...@@ -264,12 +264,12 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi ...@@ -264,12 +264,12 @@ def _make_n_folds(full_data, folds, nfold, params, seed, fpreproc=None, stratifi
# lambdarank task, split according to groups # lambdarank task, split according to groups
group_info = full_data.get_group().astype(int) group_info = full_data.get_group().astype(int)
flatted_group = np.repeat(range(len(group_info)), repeats=group_info) flatted_group = np.repeat(range(len(group_info)), repeats=group_info)
group_kfold = LGBMGroupKFold(n_splits=nfold) group_kfold = _LGBMGroupKFold(n_splits=nfold)
folds = group_kfold.split(X=np.zeros(num_data), groups=flatted_group) folds = group_kfold.split(X=np.zeros(num_data), groups=flatted_group)
elif stratified: elif stratified:
if not SKLEARN_INSTALLED: if not SKLEARN_INSTALLED:
raise LightGBMError('Scikit-learn is required for stratified cv.') raise LightGBMError('Scikit-learn is required for stratified cv.')
skf = LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed) skf = _LGBMStratifiedKFold(n_splits=nfold, shuffle=shuffle, random_state=seed)
folds = skf.split(X=np.zeros(num_data), y=full_data.get_label()) folds = skf.split(X=np.zeros(num_data), y=full_data.get_label())
else: else:
if shuffle: if shuffle:
......
This diff is collapsed.
...@@ -13,6 +13,9 @@ from sklearn.datasets import (load_boston, load_breast_cancer, load_digits, ...@@ -13,6 +13,9 @@ from sklearn.datasets import (load_boston, load_breast_cancer, load_digits,
from sklearn.externals import joblib from sklearn.externals import joblib
from sklearn.metrics import log_loss, mean_squared_error from sklearn.metrics import log_loss, mean_squared_error
from sklearn.model_selection import GridSearchCV, train_test_split from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.utils.estimator_checks import (_yield_all_checks, SkipTest,
check_parameters_default_constructible,
check_no_fit_attributes_set_in_init)
def multi_error(y_true, y_pred): def multi_error(y_true, y_pred):
...@@ -32,7 +35,7 @@ class TestSklearn(unittest.TestCase): ...@@ -32,7 +35,7 @@ class TestSklearn(unittest.TestCase):
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5, verbose=False) gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5, verbose=False)
ret = log_loss(y_test, gbm.predict_proba(X_test)) ret = log_loss(y_test, gbm.predict_proba(X_test))
self.assertLess(ret, 0.15) self.assertLess(ret, 0.15)
self.assertAlmostEqual(ret, gbm.evals_result['valid_0']['binary_logloss'][gbm.best_iteration - 1], places=5) self.assertAlmostEqual(ret, gbm.evals_result_['valid_0']['binary_logloss'][gbm.best_iteration_ - 1], places=5)
def test_regreesion(self): def test_regreesion(self):
X, y = load_boston(True) X, y = load_boston(True)
...@@ -41,7 +44,7 @@ class TestSklearn(unittest.TestCase): ...@@ -41,7 +44,7 @@ class TestSklearn(unittest.TestCase):
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5, verbose=False) gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5, verbose=False)
ret = mean_squared_error(y_test, gbm.predict(X_test)) ret = mean_squared_error(y_test, gbm.predict(X_test))
self.assertLess(ret, 16) self.assertLess(ret, 16)
self.assertAlmostEqual(ret, gbm.evals_result['valid_0']['l2'][gbm.best_iteration - 1], places=5) self.assertAlmostEqual(ret, gbm.evals_result_['valid_0']['l2'][gbm.best_iteration_ - 1], places=5)
def test_multiclass(self): def test_multiclass(self):
X, y = load_digits(10, True) X, y = load_digits(10, True)
...@@ -51,7 +54,7 @@ class TestSklearn(unittest.TestCase): ...@@ -51,7 +54,7 @@ class TestSklearn(unittest.TestCase):
ret = multi_error(y_test, gbm.predict(X_test)) ret = multi_error(y_test, gbm.predict(X_test))
self.assertLess(ret, 0.2) self.assertLess(ret, 0.2)
ret = multi_logloss(y_test, gbm.predict_proba(X_test)) ret = multi_logloss(y_test, gbm.predict_proba(X_test))
self.assertAlmostEqual(ret, gbm.evals_result['valid_0']['multi_logloss'][gbm.best_iteration - 1], places=5) self.assertAlmostEqual(ret, gbm.evals_result_['valid_0']['multi_logloss'][gbm.best_iteration_ - 1], places=5)
def test_lambdarank(self): def test_lambdarank(self):
X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/lambdarank/rank.train')) X_train, y_train = load_svmlight_file(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../examples/lambdarank/rank.train'))
...@@ -74,7 +77,7 @@ class TestSklearn(unittest.TestCase): ...@@ -74,7 +77,7 @@ class TestSklearn(unittest.TestCase):
gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5, verbose=False) gbm.fit(X_train, y_train, eval_set=[(X_test, y_test)], early_stopping_rounds=5, verbose=False)
ret = mean_squared_error(y_test, gbm.predict(X_test)) ret = mean_squared_error(y_test, gbm.predict(X_test))
self.assertLess(ret, 100) self.assertLess(ret, 100)
self.assertAlmostEqual(ret, gbm.evals_result['valid_0']['l2'][gbm.best_iteration - 1], places=5) self.assertAlmostEqual(ret, gbm.evals_result_['valid_0']['l2'][gbm.best_iteration_ - 1], places=5)
def test_binary_classification_with_custom_objective(self): def test_binary_classification_with_custom_objective(self):
def logregobj(y_true, y_pred): def logregobj(y_true, y_pred):
...@@ -177,3 +180,19 @@ class TestSklearn(unittest.TestCase): ...@@ -177,3 +180,19 @@ class TestSklearn(unittest.TestCase):
clf_2.set_params(nthread=-1).fit(X_train, y_train) clf_2.set_params(nthread=-1).fit(X_train, y_train)
self.assertEqual(len(w), 2) self.assertEqual(len(w), 2)
self.assertTrue(issubclass(w[-1].category, Warning)) self.assertTrue(issubclass(w[-1].category, Warning))
def test_sklearn_integration(self):
    """Run scikit-learn's estimator checks against LGBMClassifier and LGBMRegressor.

    For each estimator class, the constructor-level checks are run on the
    class itself, then every check yielded by ``_yield_all_checks`` is run
    on an instance. A check that raises ``SkipTest`` is reported as a
    ``SkipTestWarning`` instead of failing the test.
    """
    # Imported locally because SkipTestWarning is not among the names this
    # file imports from sklearn.utils.estimator_checks; without it the
    # except-branch below would raise NameError on the first skipped check.
    from sklearn.exceptions import SkipTestWarning

    # we cannot use `check_estimator` directly since there is no skip test mechanism
    for estimator_class in (lgb.sklearn.LGBMClassifier, lgb.sklearn.LGBMRegressor):
        name = estimator_class.__name__
        check_parameters_default_constructible(name, estimator_class)
        check_no_fit_attributes_set_in_init(name, estimator_class)
        # we cannot leave default params (see https://github.com/Microsoft/LightGBM/issues/833)
        estimator = estimator_class(min_data=1, min_data_in_bin=1)
        for check in _yield_all_checks(name, estimator):
            if check.__name__ == 'check_estimators_nan_inf':
                continue  # skip test because LightGBM deals with nan
            try:
                check(name, estimator)
            except SkipTest as message:
                # Pass the text, not the exception object: warnings.warn
                # expects a str or a Warning instance as its message.
                warnings.warn(str(message), SkipTestWarning)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment