Unverified Commit e4231205 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] use keyword arguments in predict() calls (#5755)

parent 27e69e75
...@@ -4052,9 +4052,16 @@ class Booster: ...@@ -4052,9 +4052,16 @@ class Booster:
num_iteration = self.best_iteration num_iteration = self.best_iteration
else: else:
num_iteration = -1 num_iteration = -1
return predictor.predict(data, start_iteration, num_iteration, return predictor.predict(
raw_score, pred_leaf, pred_contrib, data=data,
data_has_header, validate_features) start_iteration=start_iteration,
num_iteration=num_iteration,
raw_score=raw_score,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
data_has_header=data_has_header,
validate_features=validate_features
)
def refit( def refit(
self, self,
...@@ -4130,7 +4137,12 @@ class Booster: ...@@ -4130,7 +4137,12 @@ class Booster:
if dataset_params is None: if dataset_params is None:
dataset_params = {} dataset_params = {}
predictor = self._to_predictor(deepcopy(kwargs)) predictor = self._to_predictor(deepcopy(kwargs))
leaf_preds = predictor.predict(data, -1, pred_leaf=True, validate_features=validate_features) leaf_preds = predictor.predict(
data=data,
start_iteration=-1,
pred_leaf=True,
validate_features=validate_features
)
nrow, ncol = leaf_preds.shape nrow, ncol = leaf_preds.shape
out_is_linear = ctypes.c_int(0) out_is_linear = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetLinear( _safe_call(_LIB.LGBM_BoosterGetLinear(
......
...@@ -1135,9 +1135,16 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1135,9 +1135,16 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
**kwargs: Any **kwargs: Any
): ):
"""Docstring is inherited from the LGBMModel.""" """Docstring is inherited from the LGBMModel."""
result = self.predict_proba(X, raw_score, start_iteration, num_iteration, result = self.predict_proba(
pred_leaf, pred_contrib, validate_features, X=X,
**kwargs) raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)
if callable(self._objective) or raw_score or pred_leaf or pred_contrib: if callable(self._objective) or raw_score or pred_leaf or pred_contrib:
return result return result
else: else:
...@@ -1158,7 +1165,16 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel): ...@@ -1158,7 +1165,16 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
**kwargs: Any **kwargs: Any
): ):
"""Docstring is set after definition, using a template.""" """Docstring is set after definition, using a template."""
result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, validate_features, **kwargs) result = super().predict(
X=X,
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)
if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib): if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
_log_warning("Cannot compute class probabilities or labels " _log_warning("Cannot compute class probabilities or labels "
"due to the usage of customized objective function.\n" "due to the usage of customized objective function.\n"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment