Commit 3f0061ca authored by Nikita Titov's avatar Nikita Titov Committed by Guolin Ke
Browse files

[python] parameters renaming for sklearn naming convention (#854)

* updated scikit-learn interface

* improved description

* updated set_params()

* removed backward compatibility

* removed excess lines

* replaced pop with setdefault

* added deprecated warnings

* added tests
parent 49412ba7
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from __future__ import absolute_import from __future__ import absolute_import
import numpy as np import numpy as np
import warnings
from .basic import Dataset, LightGBMError from .basic import Dataset, LightGBMError
from .compat import (SKLEARN_INSTALLED, LGBMClassifierBase, LGBMDeprecated, from .compat import (SKLEARN_INSTALLED, LGBMClassifierBase, LGBMDeprecated,
...@@ -12,6 +13,11 @@ from .compat import (SKLEARN_INSTALLED, LGBMClassifierBase, LGBMDeprecated, ...@@ -12,6 +13,11 @@ from .compat import (SKLEARN_INSTALLED, LGBMClassifierBase, LGBMDeprecated,
from .engine import train from .engine import train
# Plain DeprecationWarning is hidden by default, so derive from UserWarning,
# which Python shows without any filter configuration.
class LGBMDeprecationWarning(UserWarning):
    """Deprecation warning for LightGBM parameters, visible by default."""
def _objective_function_wrapper(func): def _objective_function_wrapper(func):
"""Decorate an objective function """Decorate an objective function
Note: for multi-class task, the y_pred is group by class_id first, then group by row_id Note: for multi-class task, the y_pred is group by class_id first, then group by row_id
...@@ -127,51 +133,52 @@ class LGBMModel(LGBMModelBase): ...@@ -127,51 +133,52 @@ class LGBMModel(LGBMModelBase):
subsample_for_bin=50000, objective=None, subsample_for_bin=50000, objective=None,
min_split_gain=0, min_child_weight=5, min_child_samples=10, min_split_gain=0, min_child_weight=5, min_child_samples=10,
subsample=1, subsample_freq=1, colsample_bytree=1, subsample=1, subsample_freq=1, colsample_bytree=1,
reg_alpha=0, reg_lambda=0, seed=0, nthread=-1, silent=True, **kwargs): reg_alpha=0, reg_lambda=0, random_state=0,
n_jobs=-1, silent=True, **kwargs):
""" """
Implementation of the Scikit-Learn API for LightGBM. Implementation of the Scikit-Learn API for LightGBM.
Parameters Parameters
---------- ----------
boosting_type : string boosting_type : string
gbdt, traditional Gradient Boosting Decision Tree gbdt, traditional Gradient Boosting Decision Tree.
dart, Dropouts meet Multiple Additive Regression Trees dart, Dropouts meet Multiple Additive Regression Trees.
num_leaves : int num_leaves : int
Maximum tree leaves for base learners. Maximum tree leaves for base learners.
max_depth : int max_depth : int
Maximum tree depth for base learners, -1 means no limit. Maximum tree depth for base learners, -1 means no limit.
learning_rate : float learning_rate : float
Boosting learning rate Boosting learning rate.
n_estimators : int n_estimators : int
Number of boosted trees to fit. Number of boosted trees to fit.
max_bin : int max_bin : int
Number of bucketed bin for feature values Number of bucketed bin for feature values.
subsample_for_bin : int subsample_for_bin : int
Number of samples for constructing bins. Number of samples for constructing bins.
objective : string or callable objective : string or callable
Specify the learning task and the corresponding learning objective or Specify the learning task and the corresponding learning objective or
a custom objective function to be used (see note below). a custom objective function to be used (see note below).
default: binary for LGBMClassifier, lambdarank for LGBMRanker default: binary for LGBMClassifier, lambdarank for LGBMRanker.
min_split_gain : float min_split_gain : float
Minimum loss reduction required to make a further partition on a leaf node of the tree. Minimum loss reduction required to make a further partition on a leaf node of the tree.
min_child_weight : int min_child_weight : int
Minimum sum of instance weight(hessian) needed in a child(leaf) Minimum sum of instance weight(hessian) needed in a child(leaf).
min_child_samples : int min_child_samples : int
Minimum number of data need in a child(leaf) Minimum number of data need in a child(leaf).
subsample : float subsample : float
Subsample ratio of the training instance. Subsample ratio of the training instance.
subsample_freq : int subsample_freq : int
frequence of subsample, <=0 means no enable frequence of subsample, <=0 means no enable.
colsample_bytree : float colsample_bytree : float
Subsample ratio of columns when constructing each tree. Subsample ratio of columns when constructing each tree.
reg_alpha : float reg_alpha : float
L1 regularization term on weights L1 regularization term on weights.
reg_lambda : float reg_lambda : float
L2 regularization term on weights L2 regularization term on weights.
seed : int random_state : int
Random number seed. Random number seed.
nthread : int n_jobs : int
Number of parallel threads Number of parallel threads.
silent : boolean silent : boolean
Whether to print messages while running boosting. Whether to print messages while running boosting.
**kwargs : other parameters **kwargs : other parameters
...@@ -186,15 +193,15 @@ class LGBMModel(LGBMModelBase): ...@@ -186,15 +193,15 @@ class LGBMModel(LGBMModelBase):
or ``objective(y_true, y_pred, group) -> grad, hess``: or ``objective(y_true, y_pred, group) -> grad, hess``:
y_true: array_like of shape [n_samples] y_true: array_like of shape [n_samples]
The target values The target values.
y_pred: array_like of shape [n_samples] or shape[n_samples * n_class] y_pred: array_like of shape [n_samples] or shape[n_samples * n_class]
The predicted values The predicted values.
group: array_like group: array_like
group/query data, used for ranking task group/query data, used for ranking task.
grad: array_like of shape [n_samples] or shape[n_samples * n_class] grad: array_like of shape [n_samples] or shape[n_samples * n_class]
The value of the gradient for each sample point. The value of the gradient for each sample point.
hess: array_like of shape [n_samples] or shape[n_samples * n_class] hess: array_like of shape [n_samples] or shape[n_samples * n_class]
The value of the second derivative for each sample point The value of the second derivative for each sample point.
for multi-class task, the y_pred is group by class_id first, then group by row_id for multi-class task, the y_pred is group by class_id first, then group by row_id
if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i] if you want to get i-th row y_pred in j-th class, the access way is y_pred[j*num_data+i]
...@@ -229,8 +236,8 @@ class LGBMModel(LGBMModelBase): ...@@ -229,8 +236,8 @@ class LGBMModel(LGBMModelBase):
self.colsample_bytree = colsample_bytree self.colsample_bytree = colsample_bytree
self.reg_alpha = reg_alpha self.reg_alpha = reg_alpha
self.reg_lambda = reg_lambda self.reg_lambda = reg_lambda
self.seed = seed self.random_state = random_state
self.nthread = nthread self.n_jobs = n_jobs
self.silent = silent self.silent = silent
self._Booster = None self._Booster = None
self.evals_result = None self.evals_result = None
...@@ -246,6 +253,12 @@ class LGBMModel(LGBMModelBase): ...@@ -246,6 +253,12 @@ class LGBMModel(LGBMModelBase):
def get_params(self, deep=True):
    """Get parameters for this estimator.

    Merges sklearn-tracked parameters with the extra keyword arguments
    stored in ``self.other_params``, and warns when a deprecated alias
    (``seed`` or ``nthread``) is present among them.
    """
    params = super(LGBMModel, self).get_params(deep=deep)
    params.update(self.other_params)
    # deprecated alias -> sklearn-convention replacement
    renamed = (('seed', 'random_state'), ('nthread', 'n_jobs'))
    for old_name, new_name in renamed:
        if old_name in params:
            warnings.warn('The `{0}` parameter is deprecated and will be removed in next version. '
                          'Please use `{1}` instead.'.format(old_name, new_name),
                          LGBMDeprecationWarning)
    return params
# minor change to support `**kwargs` # minor change to support `**kwargs`
...@@ -333,6 +346,9 @@ class LGBMModel(LGBMModelBase): ...@@ -333,6 +346,9 @@ class LGBMModel(LGBMModelBase):
""" """
evals_result = {} evals_result = {}
params = self.get_params() params = self.get_params()
# sklearn interface has another naming convention
params.setdefault('seed', params.pop('random_state'))
params.setdefault('nthread', params.pop('n_jobs'))
# user can set verbose with kwargs, it has higher priority # user can set verbose with kwargs, it has higher priority
if 'verbose' not in params and self.silent: if 'verbose' not in params and self.silent:
params['verbose'] = -1 params['verbose'] = -1
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
import math import math
import os import os
import unittest import unittest
import warnings
import lightgbm as lgb import lightgbm as lgb
import numpy as np import numpy as np
...@@ -158,3 +159,21 @@ class TestSklearn(unittest.TestCase): ...@@ -158,3 +159,21 @@ class TestSklearn(unittest.TestCase):
clf.fit(data.data, data.target) clf.fit(data.data, data.target)
importances = clf.feature_importances_ importances = clf.feature_importances_
self.assertEqual(len(importances), 4) self.assertEqual(len(importances), 4)
def test_sklearn_backward_compatibility(self):
    """Deprecated aliases (`seed`, `nthread`) behave like their replacements and warn."""
    iris = load_iris()
    X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
    # `seed` must be treated exactly as `random_state`: identical predictions expected
    legacy_clf = lgb.sklearn.LGBMClassifier(seed=42, subsample=0.6, colsample_bytree=0.8)
    modern_clf = lgb.sklearn.LGBMClassifier(random_state=42, subsample=0.6, colsample_bytree=0.8)
    legacy_proba = legacy_clf.fit(X_train, y_train).predict_proba(X_test)
    modern_proba = modern_clf.fit(X_train, y_train).predict_proba(X_test)
    np.testing.assert_allclose(legacy_proba, modern_proba)
    # using a deprecated alias (via constructor or set_params) must raise a warning
    with warnings.catch_warnings(record=True) as caught:
        legacy_clf.get_params()
        modern_clf.set_params(nthread=-1).fit(X_train, y_train)
    self.assertEqual(len(caught), 2)
    self.assertTrue(issubclass(caught[-1].category, Warning))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment