Unverified Commit bacb33d1 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix mypy errors about _InnerPredictor (#5714)

parent 29796eee
...@@ -1375,9 +1375,9 @@ class Dataset: ...@@ -1375,9 +1375,9 @@ class Dataset:
self.categorical_feature = categorical_feature self.categorical_feature = categorical_feature
self.params = deepcopy(params) self.params = deepcopy(params)
self.free_raw_data = free_raw_data self.free_raw_data = free_raw_data
self.used_indices = None self.used_indices: Optional[List[int]] = None
self.need_slice = True self.need_slice = True
self._predictor = None self._predictor: Optional[_InnerPredictor] = None
self.pandas_categorical = None self.pandas_categorical = None
self.params_back_up = None self.params_back_up = None
self.monotone_constraints = None self.monotone_constraints = None
...@@ -1583,7 +1583,12 @@ class Dataset: ...@@ -1583,7 +1583,12 @@ class Dataset:
self.data = None self.data = None
return self return self
def _set_init_score_by_predictor(self, predictor, data, used_indices=None): def _set_init_score_by_predictor(
self,
predictor: Optional[_InnerPredictor],
data,
used_indices: Optional[List[int]]
):
data_has_header = False data_has_header = False
if isinstance(data, (str, Path)): if isinstance(data, (str, Path)):
# check data has header or not # check data has header or not
...@@ -1721,7 +1726,11 @@ class Dataset: ...@@ -1721,7 +1726,11 @@ class Dataset:
if isinstance(predictor, _InnerPredictor): if isinstance(predictor, _InnerPredictor):
if self._predictor is None and init_score is not None: if self._predictor is None and init_score is not None:
_log_warning("The init_score will be overridden by the prediction of init_model.") _log_warning("The init_score will be overridden by the prediction of init_model.")
self._set_init_score_by_predictor(predictor, data) self._set_init_score_by_predictor(
predictor=predictor,
data=data,
used_indices=None
)
elif init_score is not None: elif init_score is not None:
self.set_init_score(init_score) self.set_init_score(init_score)
elif predictor is not None: elif predictor is not None:
...@@ -2034,7 +2043,11 @@ class Dataset: ...@@ -2034,7 +2043,11 @@ class Dataset:
raise ValueError("Label should not be None.") raise ValueError("Label should not be None.")
if isinstance(self._predictor, _InnerPredictor) and self._predictor is not self.reference._predictor: if isinstance(self._predictor, _InnerPredictor) and self._predictor is not self.reference._predictor:
self.get_data() self.get_data()
self._set_init_score_by_predictor(self._predictor, self.data, used_indices) self._set_init_score_by_predictor(
predictor=self._predictor,
data=self.data,
used_indices=used_indices
)
else: else:
# create train # create train
self._lazy_init(self.data, label=self.label, self._lazy_init(self.data, label=self.label,
...@@ -2323,16 +2336,27 @@ class Dataset: ...@@ -2323,16 +2336,27 @@ class Dataset:
It is not recommended for user to call this function. It is not recommended for user to call this function.
Please use init_model argument in engine.train() or engine.cv() instead. Please use init_model argument in engine.train() or engine.cv() instead.
""" """
if predictor is self._predictor and (predictor is None or predictor.current_iteration() == self._predictor.current_iteration()): if predictor is None and self._predictor is None:
return self return self
elif isinstance(predictor, _InnerPredictor) and isinstance(self._predictor, _InnerPredictor):
if (predictor == self._predictor) and (predictor.current_iteration() == self._predictor.current_iteration()):
return self
if self.handle is None: if self.handle is None:
self._predictor = predictor self._predictor = predictor
elif self.data is not None: elif self.data is not None:
self._predictor = predictor self._predictor = predictor
self._set_init_score_by_predictor(self._predictor, self.data) self._set_init_score_by_predictor(
predictor=self._predictor,
data=self.data,
used_indices=None
)
elif self.used_indices is not None and self.reference is not None and self.reference.data is not None: elif self.used_indices is not None and self.reference is not None and self.reference.data is not None:
self._predictor = predictor self._predictor = predictor
self._set_init_score_by_predictor(self._predictor, self.reference.data, self.used_indices) self._set_init_score_by_predictor(
predictor=self._predictor,
data=self.reference.data,
used_indices=self.used_indices
)
else: else:
raise LightGBMError("Cannot set predictor after freed raw data, " raise LightGBMError("Cannot set predictor after freed raw data, "
"set free_raw_data=False when construct Dataset to avoid this.") "set free_raw_data=False when construct Dataset to avoid this.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment