"python-package/vscode:/vscode.git/clone" did not exist on "742d72f8bb051105484fd5cca11620493ffb0b2b"
Unverified Commit fc0f132f authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

[python] Fix continued train by reusing the same dataset (#2906)



* fix

* fix return

* fix test

* fix test

* fix predictor is none

* Apply suggestions from code review

* Update basic.py

* Update basic.py

* Apply suggestions from code review
Co-Authored-By: Nikita Titov <nekit94-08@mail.ru>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 399b746b
...@@ -465,11 +465,7 @@ class _InnerPredictor(object): ...@@ -465,11 +465,7 @@ class _InnerPredictor(object):
self.handle, self.handle,
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
self.num_class = out_num_class.value self.num_class = out_num_class.value
out_num_iterations = ctypes.c_int(0) self.num_total_iteration = self.current_iteration()
_safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
self.handle,
ctypes.byref(out_num_iterations)))
self.num_total_iteration = out_num_iterations.value
self.pandas_categorical = None self.pandas_categorical = None
else: else:
raise TypeError('Need model_file or booster_handle to create a predictor') raise TypeError('Need model_file or booster_handle to create a predictor')
...@@ -726,6 +722,20 @@ class _InnerPredictor(object): ...@@ -726,6 +722,20 @@ class _InnerPredictor(object):
raise ValueError("Wrong length for predict results") raise ValueError("Wrong length for predict results")
return preds, nrow return preds, nrow
def current_iteration(self):
    """Get the index of the current iteration.

    Returns
    -------
    cur_iter : int
        The index of the current iteration.
    """
    # The C API writes the iteration count into an out-parameter,
    # so allocate a ctypes int and pass its address.
    result = ctypes.c_int(0)
    status = _LIB.LGBM_BoosterGetCurrentIteration(self.handle, ctypes.byref(result))
    _safe_call(status)
    return result.value
class Dataset(object): class Dataset(object):
"""Dataset in LightGBM.""" """Dataset in LightGBM."""
...@@ -842,11 +852,12 @@ class Dataset(object): ...@@ -842,11 +852,12 @@ class Dataset(object):
if isinstance(data, string_type): if isinstance(data, string_type):
# check data has header or not # check data has header or not
data_has_header = any(self.params.get(alias, False) for alias in _ConfigAliases.get("header")) data_has_header = any(self.params.get(alias, False) for alias in _ConfigAliases.get("header"))
num_data = self.num_data()
if predictor is not None:
init_score = predictor.predict(data, init_score = predictor.predict(data,
raw_score=True, raw_score=True,
data_has_header=data_has_header, data_has_header=data_has_header,
is_reshape=False) is_reshape=False)
num_data = self.num_data()
if used_indices is not None: if used_indices is not None:
assert not self.need_slice assert not self.need_slice
if isinstance(data, string_type): if isinstance(data, string_type):
...@@ -863,6 +874,10 @@ class Dataset(object): ...@@ -863,6 +874,10 @@ class Dataset(object):
for j in range_(predictor.num_class): for j in range_(predictor.num_class):
new_init_score[j * num_data + i] = init_score[i * predictor.num_class + j] new_init_score[j * num_data + i] = init_score[i * predictor.num_class + j]
init_score = new_init_score init_score = new_init_score
elif self.init_score is not None:
init_score = np.zeros(self.init_score.shape, dtype=np.float32)
else:
return self
self.set_init_score(init_score) self.set_init_score(init_score)
def _lazy_init(self, data, label=None, reference=None, def _lazy_init(self, data, label=None, reference=None,
...@@ -1381,16 +1396,20 @@ class Dataset(object): ...@@ -1381,16 +1396,20 @@ class Dataset(object):
It is not recommended for user to call this function. It is not recommended for user to call this function.
Please use init_model argument in engine.train() or engine.cv() instead. Please use init_model argument in engine.train() or engine.cv() instead.
""" """
if predictor is self._predictor: if predictor is self._predictor and (predictor is None or predictor.current_iteration() == self._predictor.current_iteration()):
return self return self
if self.data is not None or (self.used_indices is not None if self.handle is None:
and self.reference is not None
and self.reference.data is not None):
self._predictor = predictor self._predictor = predictor
return self._free_handle() elif self.data is not None:
self._predictor = predictor
self._set_init_score_by_predictor(self._predictor, self.data)
elif self.used_indices is not None and self.reference is not None and self.reference.data is not None:
self._predictor = predictor
self._set_init_score_by_predictor(self._predictor, self.reference.data, self.used_indices)
else: else:
raise LightGBMError("Cannot set predictor after freed raw data, " raise LightGBMError("Cannot set predictor after freed raw data, "
"set free_raw_data=False when construct Dataset to avoid this.") "set free_raw_data=False when construct Dataset to avoid this.")
return self
def set_reference(self, reference): def set_reference(self, reference):
"""Set reference Dataset. """Set reference Dataset.
......
...@@ -613,6 +613,19 @@ class TestEngine(unittest.TestCase): ...@@ -613,6 +613,19 @@ class TestEngine(unittest.TestCase):
np.testing.assert_allclose(evals_result['valid_0']['l1'], evals_result['valid_0']['custom_mae']) np.testing.assert_allclose(evals_result['valid_0']['l1'], evals_result['valid_0']['custom_mae'])
os.remove(model_name) os.remove(model_name)
def test_continue_train_reused_dataset(self):
    # Continuing training several times against the SAME Dataset object must
    # keep accumulating iterations (regression test for reused-dataset bug).
    X, y = load_boston(True)
    params = {
        'objective': 'regression',
        'verbose': -1
    }
    lgb_train = lgb.Dataset(X, y, free_raw_data=False)
    # Four rounds of 5 boosting iterations each; init_model=None on the
    # first pass matches lgb.train's default (fresh model).
    booster = None
    for _ in range(4):
        booster = lgb.train(params, lgb_train, num_boost_round=5, init_model=booster)
    self.assertEqual(booster.current_iteration(), 20)
def test_continue_train_dart(self): def test_continue_train_dart(self):
X, y = load_boston(True) X, y = load_boston(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42) X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment