"src/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "c4536e227ef8b0482dde626fe39257c1b4a991b2"
Commit f3afe98b authored by aaiyer's avatar aaiyer Committed by Nikita Titov
Browse files

[python] Allow python sklearn interface's fit() to pass init_model to train() (#2447)

* allow python sklearn interface's fit() to pass init_model to train()

* Fix whitespace issues, and change ordering of parameters to be backward
compatible

* Formatting fixes

* allow python sklearn interface's fit() to pass init_model to train()

* Fix whitespace issues, and change ordering of parameters to be backward
compatible

* Formatting fixes

* Recognize LGBModel objects for init_model

* simplified condition

* updated docstring

* added test
parent 69c1c330
...@@ -376,7 +376,8 @@ class LGBMModel(_LGBMModelBase): ...@@ -376,7 +376,8 @@ class LGBMModel(_LGBMModelBase):
eval_set=None, eval_names=None, eval_sample_weight=None, eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_group=None, eval_class_weight=None, eval_init_score=None, eval_group=None,
eval_metric=None, early_stopping_rounds=None, verbose=True, eval_metric=None, early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto', callbacks=None): feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Build a gradient boosting model from the training set (X, y). """Build a gradient boosting model from the training set (X, y).
Parameters Parameters
...@@ -442,6 +443,8 @@ class LGBMModel(_LGBMModelBase): ...@@ -442,6 +443,8 @@ class LGBMModel(_LGBMModelBase):
callbacks : list of callback functions or None, optional (default=None) callbacks : list of callback functions or None, optional (default=None)
List of callback functions that are applied at each iteration. List of callback functions that are applied at each iteration.
See Callbacks in Python API for more information. See Callbacks in Python API for more information.
init_model : string, Booster, LGBMModel or None, optional (default=None)
Filename of LightGBM model, Booster instance or LGBMModel instance used to continue training.
Returns Returns
------- -------
...@@ -593,13 +596,16 @@ class LGBMModel(_LGBMModelBase): ...@@ -593,13 +596,16 @@ class LGBMModel(_LGBMModelBase):
valid_weight, valid_init_score, valid_group, params) valid_weight, valid_init_score, valid_group, params)
valid_sets.append(valid_set) valid_sets.append(valid_set)
if isinstance(init_model, LGBMModel):
init_model = init_model.booster_
self._Booster = train(params, train_set, self._Booster = train(params, train_set,
self.n_estimators, valid_sets=valid_sets, valid_names=eval_names, self.n_estimators, valid_sets=valid_sets, valid_names=eval_names,
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, fobj=self._fobj, feval=feval, evals_result=evals_result, fobj=self._fobj, feval=feval,
verbose_eval=verbose, feature_name=feature_name, verbose_eval=verbose, feature_name=feature_name,
categorical_feature=categorical_feature, categorical_feature=categorical_feature,
callbacks=callbacks) callbacks=callbacks, init_model=init_model)
if evals_result: if evals_result:
self._evals_result = evals_result self._evals_result = evals_result
...@@ -731,7 +737,8 @@ class LGBMRegressor(LGBMModel, _LGBMRegressorBase): ...@@ -731,7 +737,8 @@ class LGBMRegressor(LGBMModel, _LGBMRegressorBase):
sample_weight=None, init_score=None, sample_weight=None, init_score=None,
eval_set=None, eval_names=None, eval_sample_weight=None, eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_metric=None, early_stopping_rounds=None, eval_init_score=None, eval_metric=None, early_stopping_rounds=None,
verbose=True, feature_name='auto', categorical_feature='auto', callbacks=None): verbose=True, feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel.""" """Docstring is inherited from the LGBMModel."""
super(LGBMRegressor, self).fit(X, y, sample_weight=sample_weight, super(LGBMRegressor, self).fit(X, y, sample_weight=sample_weight,
init_score=init_score, eval_set=eval_set, init_score=init_score, eval_set=eval_set,
...@@ -742,7 +749,7 @@ class LGBMRegressor(LGBMModel, _LGBMRegressorBase): ...@@ -742,7 +749,7 @@ class LGBMRegressor(LGBMModel, _LGBMRegressorBase):
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
verbose=verbose, feature_name=feature_name, verbose=verbose, feature_name=feature_name,
categorical_feature=categorical_feature, categorical_feature=categorical_feature,
callbacks=callbacks) callbacks=callbacks, init_model=init_model)
return self return self
_base_doc = LGBMModel.fit.__doc__ _base_doc = LGBMModel.fit.__doc__
...@@ -758,7 +765,8 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -758,7 +765,8 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
eval_set=None, eval_names=None, eval_sample_weight=None, eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_metric=None, eval_class_weight=None, eval_init_score=None, eval_metric=None,
early_stopping_rounds=None, verbose=True, early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto', callbacks=None): feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel.""" """Docstring is inherited from the LGBMModel."""
_LGBMAssertAllFinite(y) _LGBMAssertAllFinite(y)
_LGBMCheckClassificationTargets(y) _LGBMCheckClassificationTargets(y)
...@@ -804,7 +812,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase): ...@@ -804,7 +812,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
verbose=verbose, feature_name=feature_name, verbose=verbose, feature_name=feature_name,
categorical_feature=categorical_feature, categorical_feature=categorical_feature,
callbacks=callbacks) callbacks=callbacks, init_model=init_model)
return self return self
fit.__doc__ = LGBMModel.fit.__doc__ fit.__doc__ = LGBMModel.fit.__doc__
...@@ -896,7 +904,8 @@ class LGBMRanker(LGBMModel): ...@@ -896,7 +904,8 @@ class LGBMRanker(LGBMModel):
eval_set=None, eval_names=None, eval_sample_weight=None, eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_group=None, eval_metric=None, eval_init_score=None, eval_group=None, eval_metric=None,
eval_at=[1], early_stopping_rounds=None, verbose=True, eval_at=[1], early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto', callbacks=None): feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel.""" """Docstring is inherited from the LGBMModel."""
# check group data # check group data
if group is None: if group is None:
...@@ -924,7 +933,7 @@ class LGBMRanker(LGBMModel): ...@@ -924,7 +933,7 @@ class LGBMRanker(LGBMModel):
early_stopping_rounds=early_stopping_rounds, early_stopping_rounds=early_stopping_rounds,
verbose=verbose, feature_name=feature_name, verbose=verbose, feature_name=feature_name,
categorical_feature=categorical_feature, categorical_feature=categorical_feature,
callbacks=callbacks) callbacks=callbacks, init_model=init_model)
return self return self
_base_doc = LGBMModel.fit.__doc__ _base_doc = LGBMModel.fit.__doc__
......
...@@ -790,3 +790,16 @@ class TestSklearn(unittest.TestCase): ...@@ -790,3 +790,16 @@ class TestSklearn(unittest.TestCase):
for metric in gbm.evals_result_[eval_set]: for metric in gbm.evals_result_[eval_set]:
np.testing.assert_allclose(gbm.evals_result_[eval_set][metric], np.testing.assert_allclose(gbm.evals_result_[eval_set][metric],
gbm_str.evals_result_[eval_set][metric]) gbm_str.evals_result_[eval_set][metric])
def test_continue_training_with_model(self):
X, y = load_digits(3, True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
init_gbm = lgb.LGBMClassifier(n_estimators=5).fit(X_train, y_train, eval_set=(X_test, y_test),
verbose=False)
gbm = lgb.LGBMClassifier(n_estimators=5).fit(X_train, y_train, eval_set=(X_test, y_test),
verbose=False, init_model=init_gbm)
self.assertEqual(len(init_gbm.evals_result_['valid_0']['multi_logloss']),
len(gbm.evals_result_['valid_0']['multi_logloss']))
self.assertEqual(len(init_gbm.evals_result_['valid_0']['multi_logloss']), 5)
self.assertLess(gbm.evals_result_['valid_0']['multi_logloss'][-1],
init_gbm.evals_result_['valid_0']['multi_logloss'][-1])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment