Unverified Commit bacb33d1 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix mypy errors about _InnerPredictor (#5714)

parent 29796eee
......@@ -1375,9 +1375,9 @@ class Dataset:
self.categorical_feature = categorical_feature
self.params = deepcopy(params)
self.free_raw_data = free_raw_data
self.used_indices = None
self.used_indices: Optional[List[int]] = None
self.need_slice = True
self._predictor = None
self._predictor: Optional[_InnerPredictor] = None
self.pandas_categorical = None
self.params_back_up = None
self.monotone_constraints = None
......@@ -1583,7 +1583,12 @@ class Dataset:
self.data = None
return self
def _set_init_score_by_predictor(self, predictor, data, used_indices=None):
def _set_init_score_by_predictor(
self,
predictor: Optional[_InnerPredictor],
data,
used_indices: Optional[List[int]]
):
data_has_header = False
if isinstance(data, (str, Path)):
# check data has header or not
......@@ -1721,7 +1726,11 @@ class Dataset:
if isinstance(predictor, _InnerPredictor):
if self._predictor is None and init_score is not None:
_log_warning("The init_score will be overridden by the prediction of init_model.")
self._set_init_score_by_predictor(predictor, data)
self._set_init_score_by_predictor(
predictor=predictor,
data=data,
used_indices=None
)
elif init_score is not None:
self.set_init_score(init_score)
elif predictor is not None:
......@@ -2034,7 +2043,11 @@ class Dataset:
raise ValueError("Label should not be None.")
if isinstance(self._predictor, _InnerPredictor) and self._predictor is not self.reference._predictor:
self.get_data()
self._set_init_score_by_predictor(self._predictor, self.data, used_indices)
self._set_init_score_by_predictor(
predictor=self._predictor,
data=self.data,
used_indices=used_indices
)
else:
# create train
self._lazy_init(self.data, label=self.label,
......@@ -2323,16 +2336,27 @@ class Dataset:
It is not recommended for user to call this function.
Please use init_model argument in engine.train() or engine.cv() instead.
"""
if predictor is self._predictor and (predictor is None or predictor.current_iteration() == self._predictor.current_iteration()):
if predictor is None and self._predictor is None:
return self
elif isinstance(predictor, _InnerPredictor) and isinstance(self._predictor, _InnerPredictor):
if (predictor == self._predictor) and (predictor.current_iteration() == self._predictor.current_iteration()):
return self
if self.handle is None:
self._predictor = predictor
elif self.data is not None:
self._predictor = predictor
self._set_init_score_by_predictor(self._predictor, self.data)
self._set_init_score_by_predictor(
predictor=self._predictor,
data=self.data,
used_indices=None
)
elif self.used_indices is not None and self.reference is not None and self.reference.data is not None:
self._predictor = predictor
self._set_init_score_by_predictor(self._predictor, self.reference.data, self.used_indices)
self._set_init_score_by_predictor(
predictor=self._predictor,
data=self.reference.data,
used_indices=self.used_indices
)
else:
raise LightGBMError("Cannot set predictor after freed raw data, "
"set free_raw_data=False when construct Dataset to avoid this.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment