"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "d4658fbb6fe943e9afee8c639934339cff38fd90"
Unverified commit fc0f132f authored by Guolin Ke, committed by GitHub

[python] Fix continued train by reusing the same dataset (#2906)



* fix

* fix return

* fix test

* fix test

* fix predictor is none

* Apply suggestions from code review

* Update basic.py

* Update basic.py

* Apply suggestions from code review
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 399b746b
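
This commit makes continued training work when the very same constructed Dataset object is reused across several lgb.train() calls with init_model (see the new test at the bottom of the diff). A minimal sketch of that usage pattern, with synthetic data standing in for the Boston dataset the test uses:

import numpy as np
import lightgbm as lgb

# Synthetic stand-in for the Boston data used in the new test.
rng = np.random.RandomState(42)
X = rng.rand(500, 10)
y = rng.rand(500)

params = {'objective': 'regression', 'verbose': -1}
# free_raw_data=False keeps the raw data around, so init scores can be
# recomputed each time training is continued on this same Dataset object.
lgb_train = lgb.Dataset(X, y, free_raw_data=False)

gbm = lgb.train(params, lgb_train, num_boost_round=5)
gbm = lgb.train(params, lgb_train, num_boost_round=5, init_model=gbm)
gbm = lgb.train(params, lgb_train, num_boost_round=5, init_model=gbm)
assert gbm.current_iteration() == 15  # 3 x 5 boosting rounds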
basic.py
@@ -465,11 +465,7 @@ class _InnerPredictor(object):
                 self.handle,
                 ctypes.byref(out_num_class)))
             self.num_class = out_num_class.value
-            out_num_iterations = ctypes.c_int(0)
-            _safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
-                self.handle,
-                ctypes.byref(out_num_iterations)))
-            self.num_total_iteration = out_num_iterations.value
+            self.num_total_iteration = self.current_iteration()
             self.pandas_categorical = None
         else:
             raise TypeError('Need model_file or booster_handle to create a predictor')
@@ -726,6 +722,20 @@ class _InnerPredictor(object):
             raise ValueError("Wrong length for predict results")
         return preds, nrow

+    def current_iteration(self):
+        """Get the index of the current iteration.
+
+        Returns
+        -------
+        cur_iter : int
+            The index of the current iteration.
+        """
+        out_cur_iter = ctypes.c_int(0)
+        _safe_call(_LIB.LGBM_BoosterGetCurrentIteration(
+            self.handle,
+            ctypes.byref(out_cur_iter)))
+        return out_cur_iter.value
+

 class Dataset(object):
     """Dataset in LightGBM."""
@@ -842,27 +852,32 @@ class Dataset(object):
         if isinstance(data, string_type):
             # check data has header or not
             data_has_header = any(self.params.get(alias, False) for alias in _ConfigAliases.get("header"))
-        init_score = predictor.predict(data,
-                                       raw_score=True,
-                                       data_has_header=data_has_header,
-                                       is_reshape=False)
         num_data = self.num_data()
-        if used_indices is not None:
-            assert not self.need_slice
-            if isinstance(data, string_type):
-                sub_init_score = np.zeros(num_data * predictor.num_class, dtype=np.float32)
-                assert num_data == len(used_indices)
-                for i in range_(len(used_indices)):
-                    for j in range_(predictor.num_class):
-                        sub_init_score[i * predictor.num_class + j] = init_score[used_indices[i] * predictor.num_class + j]
-                init_score = sub_init_score
-        if predictor.num_class > 1:
-            # need to regroup init_score
-            new_init_score = np.zeros(init_score.size, dtype=np.float32)
-            for i in range_(num_data):
-                for j in range_(predictor.num_class):
-                    new_init_score[j * num_data + i] = init_score[i * predictor.num_class + j]
-            init_score = new_init_score
+        if predictor is not None:
+            init_score = predictor.predict(data,
+                                           raw_score=True,
+                                           data_has_header=data_has_header,
+                                           is_reshape=False)
+            if used_indices is not None:
+                assert not self.need_slice
+                if isinstance(data, string_type):
+                    sub_init_score = np.zeros(num_data * predictor.num_class, dtype=np.float32)
+                    assert num_data == len(used_indices)
+                    for i in range_(len(used_indices)):
+                        for j in range_(predictor.num_class):
+                            sub_init_score[i * predictor.num_class + j] = init_score[used_indices[i] * predictor.num_class + j]
+                    init_score = sub_init_score
+            if predictor.num_class > 1:
+                # need to regroup init_score
+                new_init_score = np.zeros(init_score.size, dtype=np.float32)
+                for i in range_(num_data):
+                    for j in range_(predictor.num_class):
+                        new_init_score[j * num_data + i] = init_score[i * predictor.num_class + j]
+                init_score = new_init_score
+        elif self.init_score is not None:
+            init_score = np.zeros(self.init_score.shape, dtype=np.float32)
+        else:
+            return self
         self.set_init_score(init_score)

     def _lazy_init(self, data, label=None, reference=None,
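A note on the regrouping loop retained in this hunk: for multiclass models, predictor.predict(..., is_reshape=False) returns raw scores laid out per row (all classes for row 0, then row 1, ...), while the init score buffer is written per class (new_init_score[j * num_data + i]). The nested loop is equivalent to a transposed flatten; a small self-contained sketch, with hypothetical sizes chosen only for illustration:

import numpy as np

# Hypothetical sizes for this illustration.
num_data, num_class = 4, 3
init_score = np.arange(num_data * num_class, dtype=np.float32)  # [row0_class0, row0_class1, ..., row1_class0, ...]

# The nested loop from the hunk above:
new_init_score = np.zeros(init_score.size, dtype=np.float32)
for i in range(num_data):
    for j in range(num_class):
        new_init_score[j * num_data + i] = init_score[i * num_class + j]

# Same regrouping expressed as a transposed (Fortran-order) flatten:
assert np.array_equal(new_init_score,
                      init_score.reshape(num_data, num_class).ravel(order='F'))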
@@ -1381,16 +1396,20 @@ class Dataset(object):
         It is not recommended for user to call this function.
         Please use init_model argument in engine.train() or engine.cv() instead.
         """
-        if predictor is self._predictor:
+        if predictor is self._predictor and (predictor is None or predictor.current_iteration() == self._predictor.current_iteration()):
             return self
-        if self.data is not None or (self.used_indices is not None
-                                     and self.reference is not None
-                                     and self.reference.data is not None):
+        if self.handle is None:
             self._predictor = predictor
-            return self._free_handle()
+        elif self.data is not None:
+            self._predictor = predictor
+            self._set_init_score_by_predictor(self._predictor, self.data)
+        elif self.used_indices is not None and self.reference is not None and self.reference.data is not None:
+            self._predictor = predictor
+            self._set_init_score_by_predictor(self._predictor, self.reference.data, self.used_indices)
         else:
             raise LightGBMError("Cannot set predictor after freed raw data, "
                                 "set free_raw_data=False when construct Dataset to avoid this.")
+        return self

     def set_reference(self, reference):
         """Set reference Dataset.
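The behavioural change here: previously, setting a new predictor on a Dataset that still had raw data returned self._free_handle(), so the Dataset was torn down and rebuilt on the next training call; now the constructed Dataset is kept and only its init scores are recomputed from the predictor, either from self.data directly or through the reference Dataset for sliced subsets. For a plain single-class Dataset the effect is roughly the public-API pattern below (a sketch for intuition, not the internal code path):

import numpy as np
import lightgbm as lgb

rng = np.random.RandomState(0)
X = rng.rand(200, 5)
y = rng.rand(200)
params = {'objective': 'regression', 'verbose': -1}

booster = lgb.train(params, lgb.Dataset(X, y), num_boost_round=5)

# Roughly what _set_init_score_by_predictor boils down to in the simple case:
raw_scores = booster.predict(X, raw_score=True)        # raw output of the previous model
continued = lgb.Dataset(X, y, init_score=raw_scores)   # new trees are fit on top of these scores
booster_2 = lgb.train(params, continued, num_boost_round=5)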
test_engine.py
@@ -613,6 +613,19 @@ class TestEngine(unittest.TestCase):
         np.testing.assert_allclose(evals_result['valid_0']['l1'], evals_result['valid_0']['custom_mae'])
         os.remove(model_name)

+    def test_continue_train_reused_dataset(self):
+        X, y = load_boston(True)
+        params = {
+            'objective': 'regression',
+            'verbose': -1
+        }
+        lgb_train = lgb.Dataset(X, y, free_raw_data=False)
+        init_gbm = lgb.train(params, lgb_train, num_boost_round=5)
+        init_gbm_2 = lgb.train(params, lgb_train, num_boost_round=5, init_model=init_gbm)
+        init_gbm_3 = lgb.train(params, lgb_train, num_boost_round=5, init_model=init_gbm_2)
+        gbm = lgb.train(params, lgb_train, num_boost_round=5, init_model=init_gbm_3)
+        self.assertEqual(gbm.current_iteration(), 20)
+
     def test_continue_train_dart(self):
         X, y = load_boston(True)
         X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)