"docs/zh_cn/user_guides/test.md" did not exist on "fdfe3c4f8ba935ae428a8a496ce57755d5b2ea98"
Unverified Commit e4231205 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] use keyword arguments in predict() calls (#5755)

parent 27e69e75
......@@ -4052,9 +4052,16 @@ class Booster:
num_iteration = self.best_iteration
else:
num_iteration = -1
return predictor.predict(data, start_iteration, num_iteration,
raw_score, pred_leaf, pred_contrib,
data_has_header, validate_features)
return predictor.predict(
data=data,
start_iteration=start_iteration,
num_iteration=num_iteration,
raw_score=raw_score,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
data_has_header=data_has_header,
validate_features=validate_features
)
def refit(
self,
......@@ -4130,7 +4137,12 @@ class Booster:
if dataset_params is None:
dataset_params = {}
predictor = self._to_predictor(deepcopy(kwargs))
leaf_preds = predictor.predict(data, -1, pred_leaf=True, validate_features=validate_features)
leaf_preds = predictor.predict(
data=data,
start_iteration=-1,
pred_leaf=True,
validate_features=validate_features
)
nrow, ncol = leaf_preds.shape
out_is_linear = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetLinear(
......
......@@ -1135,9 +1135,16 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
**kwargs: Any
):
"""Docstring is inherited from the LGBMModel."""
result = self.predict_proba(X, raw_score, start_iteration, num_iteration,
pred_leaf, pred_contrib, validate_features,
**kwargs)
result = self.predict_proba(
X=X,
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)
if callable(self._objective) or raw_score or pred_leaf or pred_contrib:
return result
else:
......@@ -1158,7 +1165,16 @@ class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
**kwargs: Any
):
"""Docstring is set after definition, using a template."""
result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, validate_features, **kwargs)
result = super().predict(
X=X,
raw_score=raw_score,
start_iteration=start_iteration,
num_iteration=num_iteration,
pred_leaf=pred_leaf,
pred_contrib=pred_contrib,
validate_features=validate_features,
**kwargs
)
if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
_log_warning("Cannot compute class probabilities or labels "
"due to the usage of customized objective function.\n"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment