Unverified Commit cf0a992e authored by Lukas Pfannschmidt, committed by GitHub
Browse files

[python] handle RandomState object in Scikit-learn Api (#2904)



* Add handling of RandomState object, which is standard for sklearn methods.

LightGBM expects an integer seed instead of an object.
If the passed object is a RandomState, we derive a random integer from its state to seed the underlying low-level code.
Although the derived integer is drawn from a bounded range, I expect it to have enough entropy that this does not matter in practice.

* Add RandomState object to random_state docstring.

* remove blank line

* Use property to handle setting random_state.
This enables setting cloned estimators with the set_params method in sklearn.

* Add docstring to attribute.

* Fix and simplify docstring.

* Add test case.

* Use maximal int for datatype in seed derivation.

* Replace random_state property with interfacing in fit method.
Derives int seed for C code only when fitting and keeps RandomState object as param.

* Adapt unit test to property change.

* Extended test case and docstring
Co-Authored-By: Nikita Titov <nekit94-08@mail.ru>

* Add more equality checks (feature importance, best iteration/score).

* Add equality comparison of boosters represented by strings.
Remove useless best_iteration_ comparison (we do not use early_stopping).

* fix whitespace

* Test if two subsequent fits produce different models

* Apply suggestions from code review
Co-Authored-By: Nikita Titov <nekit94-08@mail.ru>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent e32d34ab
...@@ -230,9 +230,11 @@ class LGBMModel(_LGBMModelBase): ...@@ -230,9 +230,11 @@ class LGBMModel(_LGBMModelBase):
L1 regularization term on weights. L1 regularization term on weights.
reg_lambda : float, optional (default=0.) reg_lambda : float, optional (default=0.)
L2 regularization term on weights. L2 regularization term on weights.
random_state : int or None, optional (default=None) random_state : int, RandomState object or None, optional (default=None)
Random number seed. Random number seed.
If None, default seeds in C++ code will be used. If int, this number is used to seed the C++ code.
If RandomState object (numpy), a random integer is picked based on its state to seed the C++ code.
If None, default seeds in C++ code are used.
n_jobs : int, optional (default=-1) n_jobs : int, optional (default=-1)
Number of parallel threads. Number of parallel threads.
silent : bool, optional (default=True) silent : bool, optional (default=True)
...@@ -503,6 +505,8 @@ class LGBMModel(_LGBMModelBase): ...@@ -503,6 +505,8 @@ class LGBMModel(_LGBMModelBase):
params.pop('importance_type', None) params.pop('importance_type', None)
params.pop('n_estimators', None) params.pop('n_estimators', None)
params.pop('class_weight', None) params.pop('class_weight', None)
if isinstance(params['random_state'], np.random.RandomState):
params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max)
for alias in _ConfigAliases.get('objective'): for alias in _ConfigAliases.get('objective'):
params.pop(alias, None) params.pop(alias, None)
if self._n_classes is not None and self._n_classes > 2: if self._n_classes is not None and self._n_classes > 2:
......
...@@ -227,6 +227,39 @@ class TestSklearn(unittest.TestCase): ...@@ -227,6 +227,39 @@ class TestSklearn(unittest.TestCase):
pred_pickle = gbm_pickle.predict(X_test) pred_pickle = gbm_pickle.predict(X_test)
np.testing.assert_allclose(pred_origin, pred_pickle) np.testing.assert_allclose(pred_origin, pred_pickle)
def test_random_state_object(self):
    """np.random.RandomState objects passed as ``random_state`` must be honoured.

    Two estimators seeded with identically-initialised RandomState objects
    have to produce identical models, while a subsequent refit must draw a
    fresh seed from the object and therefore yield a different model.
    """
    X, y = load_iris(True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
    # Two distinct RandomState objects initialised with the same seed.
    rng_a = np.random.RandomState(123)
    rng_b = np.random.RandomState(123)
    shared_params = dict(n_estimators=10, subsample=0.5, subsample_freq=1)
    clf_a = lgb.LGBMClassifier(random_state=rng_a, **shared_params)
    clf_b = lgb.LGBMClassifier(random_state=rng_b, **shared_params)
    # The RandomState objects themselves must be stored on the estimators,
    # not eagerly converted to integers.
    self.assertIs(clf_a.random_state, rng_a)
    self.assertIs(clf_b.random_state, rng_b)
    # Identically-seeded states must produce identical models.
    clf_a.fit(X_train, y_train)
    clf_b.fit(X_train, y_train)
    pred_a = clf_a.predict(X_test, raw_score=True)
    pred_b = clf_b.predict(X_test, raw_score=True)
    np.testing.assert_allclose(pred_a, pred_b)
    np.testing.assert_array_equal(clf_a.feature_importances_, clf_b.feature_importances_)
    model_a = clf_a.booster_.model_to_string(num_iteration=0)
    model_b = clf_b.booster_.model_to_string(num_iteration=0)
    self.assertMultiLineEqual(model_a, model_b)
    # Refitting samples a new seed from the RandomState object, so the
    # resulting model must differ while the stored object stays the same.
    clf_a.fit(X_train, y_train)
    pred_a_refit = clf_a.predict(X_test, raw_score=True)
    model_a_refit = clf_a.booster_.model_to_string(num_iteration=0)
    self.assertIs(clf_a.random_state, rng_a)
    self.assertIs(clf_b.random_state, rng_b)
    self.assertRaises(AssertionError,
                      np.testing.assert_allclose,
                      pred_a, pred_a_refit)
    self.assertRaises(AssertionError,
                      self.assertMultiLineEqual,
                      model_a, model_a_refit)
def test_feature_importances_single_leaf(self): def test_feature_importances_single_leaf(self):
data = load_iris() data = load_iris()
clf = lgb.LGBMClassifier(n_estimators=10) clf = lgb.LGBMClassifier(n_estimators=10)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment