"include/LightGBM/vscode:/vscode.git/clone" did not exist on "b24190a68fbd17c4a7eb42e679f6ce99b5132e34"
Unverified Commit 78207462 authored by José Morales's avatar José Morales Committed by GitHub
Browse files

[python-package] remove is_reshape argument in Booster.predict (fixes #5115) (#5117)

parent db089854
......@@ -752,8 +752,7 @@ class _InnerPredictor:
return this
def predict(self, data, start_iteration=0, num_iteration=-1,
raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False,
is_reshape=True):
raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False):
"""Predict logic.
Parameters
......@@ -774,8 +773,6 @@ class _InnerPredictor:
data_has_header : bool, optional (default=False)
Whether data has header.
Used only for txt data.
is_reshape : bool, optional (default=True)
Whether to reshape to (nrow, ncol).
Returns
-------
......@@ -832,7 +829,7 @@ class _InnerPredictor:
if pred_leaf:
preds = preds.astype(np.int32)
is_sparse = scipy.sparse.issparse(preds) or isinstance(preds, list)
if is_reshape and not is_sparse and preds.size != nrow:
if not is_sparse and preds.size != nrow:
if preds.size % nrow == 0:
preds = preds.reshape(nrow, -1)
else:
......@@ -1403,8 +1400,8 @@ class Dataset:
if predictor is not None:
init_score = predictor.predict(data,
raw_score=True,
data_has_header=data_has_header,
is_reshape=False)
data_has_header=data_has_header)
init_score = init_score.ravel()
if used_indices is not None:
assert not self.need_slice
if isinstance(data, (str, Path)):
......@@ -3489,7 +3486,7 @@ class Booster:
def predict(self, data, start_iteration=0, num_iteration=None,
raw_score=False, pred_leaf=False, pred_contrib=False,
data_has_header=False, is_reshape=True, **kwargs):
data_has_header=False, **kwargs):
"""Make a prediction.
Parameters
......@@ -3523,8 +3520,6 @@ class Booster:
data_has_header : bool, optional (default=False)
Whether the data has header.
Used only if data is str.
is_reshape : bool, optional (default=True)
If True, result is reshaped to [nrow, ncol].
**kwargs
Other parameters for the prediction.
......@@ -3542,7 +3537,7 @@ class Booster:
num_iteration = -1
return predictor.predict(data, start_iteration, num_iteration,
raw_score, pred_leaf, pred_contrib,
data_has_header, is_reshape)
data_has_header)
def refit(
self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment