Commit f3afe98b authored by aaiyer's avatar aaiyer Committed by Nikita Titov
Browse files

[python] Allow python sklearn interface's fit() to pass init_model to train() (#2447)

* allow python sklearn interface's fit() to pass init_model to train()

* Fix whitespace issues, and change ordering of parameters to be backward
compatible

* Formatting fixes

* allow python sklearn interface's fit() to pass init_model to train()

* Fix whitespace issues, and change ordering of parameters to be backward
compatible

* Formatting fixes

* Recognize LGBModel objects for init_model

* simplified condition

* updated docstring

* added test
parent 69c1c330
......@@ -376,7 +376,8 @@ class LGBMModel(_LGBMModelBase):
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_group=None,
eval_metric=None, early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto', callbacks=None):
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Build a gradient boosting model from the training set (X, y).
Parameters
......@@ -442,6 +443,8 @@ class LGBMModel(_LGBMModelBase):
callbacks : list of callback functions or None, optional (default=None)
List of callback functions that are applied at each iteration.
See Callbacks in Python API for more information.
init_model : string, Booster, LGBMModel or None, optional (default=None)
Filename of LightGBM model, Booster instance or LGBMModel instance used for continue training.
Returns
-------
......@@ -593,13 +596,16 @@ class LGBMModel(_LGBMModelBase):
valid_weight, valid_init_score, valid_group, params)
valid_sets.append(valid_set)
if isinstance(init_model, LGBMModel):
init_model = init_model.booster_
self._Booster = train(params, train_set,
self.n_estimators, valid_sets=valid_sets, valid_names=eval_names,
early_stopping_rounds=early_stopping_rounds,
evals_result=evals_result, fobj=self._fobj, feval=feval,
verbose_eval=verbose, feature_name=feature_name,
categorical_feature=categorical_feature,
callbacks=callbacks)
callbacks=callbacks, init_model=init_model)
if evals_result:
self._evals_result = evals_result
......@@ -731,7 +737,8 @@ class LGBMRegressor(LGBMModel, _LGBMRegressorBase):
sample_weight=None, init_score=None,
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_metric=None, early_stopping_rounds=None,
verbose=True, feature_name='auto', categorical_feature='auto', callbacks=None):
verbose=True, feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
super(LGBMRegressor, self).fit(X, y, sample_weight=sample_weight,
init_score=init_score, eval_set=eval_set,
......@@ -742,7 +749,7 @@ class LGBMRegressor(LGBMModel, _LGBMRegressorBase):
early_stopping_rounds=early_stopping_rounds,
verbose=verbose, feature_name=feature_name,
categorical_feature=categorical_feature,
callbacks=callbacks)
callbacks=callbacks, init_model=init_model)
return self
_base_doc = LGBMModel.fit.__doc__
......@@ -758,7 +765,8 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_class_weight=None, eval_init_score=None, eval_metric=None,
early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto', callbacks=None):
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
_LGBMAssertAllFinite(y)
_LGBMCheckClassificationTargets(y)
......@@ -804,7 +812,7 @@ class LGBMClassifier(LGBMModel, _LGBMClassifierBase):
early_stopping_rounds=early_stopping_rounds,
verbose=verbose, feature_name=feature_name,
categorical_feature=categorical_feature,
callbacks=callbacks)
callbacks=callbacks, init_model=init_model)
return self
fit.__doc__ = LGBMModel.fit.__doc__
......@@ -896,7 +904,8 @@ class LGBMRanker(LGBMModel):
eval_set=None, eval_names=None, eval_sample_weight=None,
eval_init_score=None, eval_group=None, eval_metric=None,
eval_at=[1], early_stopping_rounds=None, verbose=True,
feature_name='auto', categorical_feature='auto', callbacks=None):
feature_name='auto', categorical_feature='auto',
callbacks=None, init_model=None):
"""Docstring is inherited from the LGBMModel."""
# check group data
if group is None:
......@@ -924,7 +933,7 @@ class LGBMRanker(LGBMModel):
early_stopping_rounds=early_stopping_rounds,
verbose=verbose, feature_name=feature_name,
categorical_feature=categorical_feature,
callbacks=callbacks)
callbacks=callbacks, init_model=init_model)
return self
_base_doc = LGBMModel.fit.__doc__
......
......@@ -790,3 +790,16 @@ class TestSklearn(unittest.TestCase):
for metric in gbm.evals_result_[eval_set]:
np.testing.assert_allclose(gbm.evals_result_[eval_set][metric],
gbm_str.evals_result_[eval_set][metric])
def test_continue_training_with_model(self):
    """Check that ``fit(init_model=...)`` continues training from an existing LGBMModel.

    Trains a 5-tree classifier, then trains a second 5-tree classifier
    seeded with the first via ``init_model``, and verifies that the
    continued model records the same number of per-round evaluation
    results and ends with a lower multi-logloss than the initial model.
    """
    # Use keyword arguments: positional use of load_digits parameters is
    # deprecated in scikit-learn 0.23+ and removed in 1.1+.
    X, y = load_digits(n_class=3, return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
    init_gbm = lgb.LGBMClassifier(n_estimators=5).fit(X_train, y_train,
                                                      eval_set=(X_test, y_test),
                                                      verbose=False)
    gbm = lgb.LGBMClassifier(n_estimators=5).fit(X_train, y_train,
                                                 eval_set=(X_test, y_test),
                                                 verbose=False, init_model=init_gbm)
    # Each run logs one metric value per boosting round, so both histories
    # should have length n_estimators (= 5).
    self.assertEqual(len(init_gbm.evals_result_['valid_0']['multi_logloss']),
                     len(gbm.evals_result_['valid_0']['multi_logloss']))
    self.assertEqual(len(init_gbm.evals_result_['valid_0']['multi_logloss']), 5)
    # Continued training starts from the initial model's trees, so its
    # final validation loss should improve on the initial model's.
    self.assertLess(gbm.evals_result_['valid_0']['multi_logloss'][-1],
                    init_gbm.evals_result_['valid_0']['multi_logloss'][-1])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment