Unverified Commit 78207462 authored by José Morales's avatar José Morales Committed by GitHub
Browse files

[python-package] remove is_reshape argument in Booster.predict (fixes #5115) (#5117)

parent db089854
...@@ -752,8 +752,7 @@ class _InnerPredictor: ...@@ -752,8 +752,7 @@ class _InnerPredictor:
return this return this
def predict(self, data, start_iteration=0, num_iteration=-1, def predict(self, data, start_iteration=0, num_iteration=-1,
raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False, raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False):
is_reshape=True):
"""Predict logic. """Predict logic.
Parameters Parameters
...@@ -774,8 +773,6 @@ class _InnerPredictor: ...@@ -774,8 +773,6 @@ class _InnerPredictor:
data_has_header : bool, optional (default=False) data_has_header : bool, optional (default=False)
Whether data has header. Whether data has header.
Used only for txt data. Used only for txt data.
is_reshape : bool, optional (default=True)
Whether to reshape to (nrow, ncol).
Returns Returns
------- -------
...@@ -832,7 +829,7 @@ class _InnerPredictor: ...@@ -832,7 +829,7 @@ class _InnerPredictor:
if pred_leaf: if pred_leaf:
preds = preds.astype(np.int32) preds = preds.astype(np.int32)
is_sparse = scipy.sparse.issparse(preds) or isinstance(preds, list) is_sparse = scipy.sparse.issparse(preds) or isinstance(preds, list)
if is_reshape and not is_sparse and preds.size != nrow: if not is_sparse and preds.size != nrow:
if preds.size % nrow == 0: if preds.size % nrow == 0:
preds = preds.reshape(nrow, -1) preds = preds.reshape(nrow, -1)
else: else:
...@@ -1403,8 +1400,8 @@ class Dataset: ...@@ -1403,8 +1400,8 @@ class Dataset:
if predictor is not None: if predictor is not None:
init_score = predictor.predict(data, init_score = predictor.predict(data,
raw_score=True, raw_score=True,
data_has_header=data_has_header, data_has_header=data_has_header)
is_reshape=False) init_score = init_score.ravel()
if used_indices is not None: if used_indices is not None:
assert not self.need_slice assert not self.need_slice
if isinstance(data, (str, Path)): if isinstance(data, (str, Path)):
...@@ -3489,7 +3486,7 @@ class Booster: ...@@ -3489,7 +3486,7 @@ class Booster:
def predict(self, data, start_iteration=0, num_iteration=None, def predict(self, data, start_iteration=0, num_iteration=None,
raw_score=False, pred_leaf=False, pred_contrib=False, raw_score=False, pred_leaf=False, pred_contrib=False,
data_has_header=False, is_reshape=True, **kwargs): data_has_header=False, **kwargs):
"""Make a prediction. """Make a prediction.
Parameters Parameters
...@@ -3523,8 +3520,6 @@ class Booster: ...@@ -3523,8 +3520,6 @@ class Booster:
data_has_header : bool, optional (default=False) data_has_header : bool, optional (default=False)
Whether the data has header. Whether the data has header.
Used only if data is str. Used only if data is str.
is_reshape : bool, optional (default=True)
If True, result is reshaped to [nrow, ncol].
**kwargs **kwargs
Other parameters for the prediction. Other parameters for the prediction.
...@@ -3542,7 +3537,7 @@ class Booster: ...@@ -3542,7 +3537,7 @@ class Booster:
num_iteration = -1 num_iteration = -1
return predictor.predict(data, start_iteration, num_iteration, return predictor.predict(data, start_iteration, num_iteration,
raw_score, pred_leaf, pred_contrib, raw_score, pred_leaf, pred_contrib,
data_has_header, is_reshape) data_has_header)
def refit( def refit(
self, self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment