Unverified Commit 426dfcca authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints on Booster predict() methods (#5749)

parent e8cdc2c9
...@@ -41,6 +41,14 @@ _LGBM_LabelType = Union[ ...@@ -41,6 +41,14 @@ _LGBM_LabelType = Union[
pd_Series, pd_Series,
pd_DataFrame pd_DataFrame
] ]
_LGBM_PredictDataType = Union[
str,
Path,
np.ndarray,
pd_DataFrame,
dt_DataTable,
scipy.sparse.spmatrix
]
ZERO_THRESHOLD = 1e-35 ZERO_THRESHOLD = 1e-35
...@@ -826,7 +834,7 @@ class _InnerPredictor: ...@@ -826,7 +834,7 @@ class _InnerPredictor:
def predict( def predict(
self, self,
data, data: _LGBM_PredictDataType,
start_iteration: int = 0, start_iteration: int = 0,
num_iteration: int = -1, num_iteration: int = -1,
raw_score: bool = False, raw_score: bool = False,
...@@ -834,7 +842,7 @@ class _InnerPredictor: ...@@ -834,7 +842,7 @@ class _InnerPredictor:
pred_contrib: bool = False, pred_contrib: bool = False,
data_has_header: bool = False, data_has_header: bool = False,
validate_features: bool = False validate_features: bool = False
): ) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
"""Predict logic. """Predict logic.
Parameters Parameters
...@@ -3941,7 +3949,7 @@ class Booster: ...@@ -3941,7 +3949,7 @@ class Booster:
def predict( def predict(
self, self,
data, data: _LGBM_PredictDataType,
start_iteration: int = 0, start_iteration: int = 0,
num_iteration: Optional[int] = None, num_iteration: Optional[int] = None,
raw_score: bool = False, raw_score: bool = False,
...@@ -3950,7 +3958,7 @@ class Booster: ...@@ -3950,7 +3958,7 @@ class Booster:
data_has_header: bool = False, data_has_header: bool = False,
validate_features: bool = False, validate_features: bool = False,
**kwargs: Any **kwargs: Any
): ) -> Union[np.ndarray, scipy.sparse.spmatrix, List[scipy.sparse.spmatrix]]:
"""Make a prediction. """Make a prediction.
Parameters Parameters
...@@ -4021,7 +4029,7 @@ class Booster: ...@@ -4021,7 +4029,7 @@ class Booster:
free_raw_data: bool = True, free_raw_data: bool = True,
validate_features: bool = False, validate_features: bool = False,
**kwargs **kwargs
): ) -> "Booster":
"""Refit the existing Booster by new data. """Refit the existing Booster by new data.
Parameters Parameters
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment