Unverified Commit d0dfceec authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[ci] [python-package] fix mypy errors about Dataset._set_init_score_by_predictor() (#5850)

parent 8670013d
......@@ -1744,9 +1744,11 @@ class Dataset:
data_has_header = any(self.params.get(alias, False) for alias in _ConfigAliases.get("header"))
num_data = self.num_data()
if predictor is not None:
init_score = predictor.predict(data,
raw_score=True,
data_has_header=data_has_header)
init_score: Union[np.ndarray, scipy.sparse.spmatrix] = predictor.predict(
data=data,
raw_score=True,
data_has_header=data_has_header
)
init_score = init_score.ravel()
if used_indices is not None:
assert not self._need_slice
......@@ -1765,7 +1767,7 @@ class Dataset:
new_init_score[j * num_data + i] = init_score[i * predictor.num_class + j]
init_score = new_init_score
elif self.init_score is not None:
init_score = np.zeros(self.init_score.shape, dtype=np.float64)
init_score = np.full_like(self.init_score, fill_value=0.0, dtype=np.float64)
else:
return self
self.set_init_score(init_score)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment