Commit 868b6242 authored by Nikita Titov's avatar Nikita Titov Committed by Guolin Ke
Browse files

[python] copy old Booster's attributes to new one in refit method (#1699)

* copy old Booster's attributes to new one in refit method

* fixed according to review comment

* raise error in case of null objective
parent 3f2f24a4
...@@ -1748,7 +1748,7 @@ class Booster(object): ...@@ -1748,7 +1748,7 @@ class Booster(object):
is_finished = ctypes.c_int(0) is_finished = ctypes.c_int(0)
if fobj is None: if fobj is None:
if self.__set_objective_to_none: if self.__set_objective_to_none:
raise ValueError('Cannot update due to null objective function.') raise LightGBMError('Cannot update due to null objective function.')
_safe_call(_LIB.LGBM_BoosterUpdateOneIter( _safe_call(_LIB.LGBM_BoosterUpdateOneIter(
self.handle, self.handle,
ctypes.byref(is_finished))) ctypes.byref(is_finished)))
...@@ -2167,6 +2167,8 @@ class Booster(object): ...@@ -2167,6 +2167,8 @@ class Booster(object):
result : Booster result : Booster
Refitted Booster. Refitted Booster.
""" """
if self.__set_objective_to_none:
raise LightGBMError('Cannot refit due to null objective function.')
predictor = self._to_predictor(kwargs) predictor = self._to_predictor(kwargs)
leaf_preds = predictor.predict(data, -1, pred_leaf=True) leaf_preds = predictor.predict(data, -1, pred_leaf=True)
nrow, ncol = leaf_preds.shape nrow, ncol = leaf_preds.shape
...@@ -2183,6 +2185,8 @@ class Booster(object): ...@@ -2183,6 +2185,8 @@ class Booster(object):
ptr_data, ptr_data,
ctypes.c_int(nrow), ctypes.c_int(nrow),
ctypes.c_int(ncol))) ctypes.c_int(ncol)))
new_booster.network = self.network
new_booster.__attr = self.__attr.copy()
return new_booster return new_booster
def get_leaf_output(self, tree_id, leaf_id): def get_leaf_output(self, tree_id, leaf_id):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment