Unverified Commit fc0f132f authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

[python] Fix continued train by reusing the same dataset (#2906)



* fix

* fix return

* fix test

* fix test

* fix predictor is none

* Apply suggestions from code review

* Update basic.py

* Update basic.py

* Apply suggestions from code review
Co-Authored-By: default avatarNikita Titov <nekit94-08@mail.ru>
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent 399b746b
......@@ -465,11 +465,7 @@ class _InnerPredictor(object):
self.handle,
ctypes.byref(out_num_class)))
self.num_class = out_num_class.value
out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
self.handle,
ctypes.byref(out_num_iterations)))
self.num_total_iteration = out_num_iterations.value
self.num_total_iteration = self.current_iteration()
self.pandas_categorical = None
else:
raise TypeError('Need model_file or booster_handle to create a predictor')
......@@ -726,6 +722,20 @@ class _InnerPredictor(object):
raise ValueError("Wrong length for predict results")
return preds, nrow
def current_iteration(self):
"""Get the index of the current iteration.
Returns
-------
cur_iter : int
The index of the current iteration.
"""
out_cur_iter = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
self.handle,
ctypes.byref(out_cur_iter)))
return out_cur_iter.value
class Dataset(object):
"""Dataset in LightGBM."""
......@@ -842,11 +852,12 @@ class Dataset(object):
if isinstance(data, string_type):
# check data has header or not
data_has_header = any(self.params.get(alias, False) for alias in _ConfigAliases.get("header"))
num_data = self.num_data()
if predictor is not None:
init_score = predictor.predict(data,
raw_score=True,
data_has_header=data_has_header,
is_reshape=False)
num_data = self.num_data()
if used_indices is not None:
assert not self.need_slice
if isinstance(data, string_type):
......@@ -863,6 +874,10 @@ class Dataset(object):
for j in range_(predictor.num_class):
new_init_score[j * num_data + i] = init_score[i * predictor.num_class + j]
init_score = new_init_score
elif self.init_score is not None:
init_score = np.zeros(self.init_score.shape, dtype=np.float32)
else:
return self
self.set_init_score(init_score)
def _lazy_init(self, data, label=None, reference=None,
......@@ -1381,16 +1396,20 @@ class Dataset(object):
It is not recommended for user to call this function.
Please use init_model argument in engine.train() or engine.cv() instead.
"""
if predictor is self._predictor:
if predictor is self._predictor and (predictor is None or predictor.current_iteration() == self._predictor.current_iteration()):
return self
if self.data is not None or (self.used_indices is not None
and self.reference is not None
and self.reference.data is not None):
if self.handle is None:
self._predictor = predictor
return self._free_handle()
elif self.data is not None:
self._predictor = predictor
self._set_init_score_by_predictor(self._predictor, self.data)
elif self.used_indices is not None and self.reference is not None and self.reference.data is not None:
self._predictor = predictor
self._set_init_score_by_predictor(self._predictor, self.reference.data, self.used_indices)
else:
raise LightGBMError("Cannot set predictor after freed raw data, "
"set free_raw_data=False when construct Dataset to avoid this.")
return self
def set_reference(self, reference):
"""Set reference Dataset.
......
......@@ -613,6 +613,19 @@ class TestEngine(unittest.TestCase):
np.testing.assert_allclose(evals_result['valid_0']['l1'], evals_result['valid_0']['custom_mae'])
os.remove(model_name)
def test_continue_train_reused_dataset(self):
X, y = load_boston(True)
params = {
'objective': 'regression',
'verbose': -1
}
lgb_train = lgb.Dataset(X, y, free_raw_data=False)
init_gbm = lgb.train(params, lgb_train, num_boost_round=5)
init_gbm_2 = lgb.train(params, lgb_train, num_boost_round=5, init_model=init_gbm)
init_gbm_3 = lgb.train(params, lgb_train, num_boost_round=5, init_model=init_gbm_2)
gbm = lgb.train(params, lgb_train, num_boost_round=5, init_model=init_gbm_3)
self.assertEqual(gbm.current_iteration(), 20)
def test_continue_train_dart(self):
X, y = load_boston(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment