"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "b3c126629ee5ce83c4d9b5111eaec678af8de421"
Unverified Commit 39ed8ea2 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints on Dataset feature processing (#5745)

parent 00073437
......@@ -33,6 +33,8 @@ _DatasetHandle = ctypes.c_void_p
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
_LGBM_CategoricalFeatureConfiguration = Union[List[str], List[int], str]
_LGBM_FeatureNameConfiguration = Union[List[str], str]
_LGBM_LabelType = Union[
list,
np.ndarray,
......@@ -588,7 +590,12 @@ def _check_for_bad_pandas_dtypes(pandas_dtypes_series: pd_Series) -> None:
f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}')
def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical):
def _data_from_pandas(
data,
feature_name: Optional[_LGBM_FeatureNameConfiguration],
categorical_feature: Optional[_LGBM_CategoricalFeatureConfiguration],
pandas_categorical: Optional[List[List]]
):
if isinstance(data, pd_DataFrame):
if len(data.shape) != 2 or data.shape[0] < 1:
raise ValueError('Input data must be 2 dimensional and non empty.')
......@@ -638,7 +645,10 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
return data, feature_name, categorical_feature, pandas_categorical
def _dump_pandas_categorical(pandas_categorical, file_name=None):
def _dump_pandas_categorical(
pandas_categorical: Optional[List[List]],
file_name: Optional[Union[str, Path]] = None
) -> str:
categorical_json = json.dumps(pandas_categorical, default=_json_default_with_numpy)
pandas_str = f'\npandas_categorical:{categorical_json}\n'
if file_name is not None:
......@@ -650,7 +660,7 @@ def _dump_pandas_categorical(pandas_categorical, file_name=None):
def _load_pandas_categorical(
file_name: Optional[Union[str, Path]] = None,
model_str: Optional[str] = None
) -> Optional[str]:
) -> Optional[List[List]]:
pandas_key = 'pandas_categorical:'
offset = -len(pandas_key)
if file_name is not None:
......@@ -1320,8 +1330,8 @@ class Dataset:
weight=None,
group=None,
init_score=None,
feature_name='auto',
categorical_feature='auto',
feature_name: _LGBM_FeatureNameConfiguration = 'auto',
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto',
params: Optional[Dict[str, Any]] = None,
free_raw_data: bool = True
):
......@@ -1371,8 +1381,8 @@ class Dataset:
self.weight = weight
self.group = group
self.init_score = init_score
self.feature_name = feature_name
self.categorical_feature = categorical_feature
self.feature_name: _LGBM_FeatureNameConfiguration = feature_name
self.categorical_feature: _LGBM_CategoricalFeatureConfiguration = categorical_feature
self.params = deepcopy(params)
self.free_raw_data = free_raw_data
self.used_indices: Optional[List[int]] = None
......@@ -2294,13 +2304,13 @@ class Dataset:
def set_categorical_feature(
self,
categorical_feature: Union[List[int], List[str], str]
categorical_feature: _LGBM_CategoricalFeatureConfiguration
) -> "Dataset":
"""Set categorical features.
Parameters
----------
categorical_feature : list of int or str
categorical_feature : list of str or int, or 'auto'
Names or indices of categorical features.
Returns
......@@ -3937,8 +3947,8 @@ class Booster:
weight=None,
group=None,
init_score=None,
feature_name: Union[str, List[str]] = 'auto',
categorical_feature: Union[str, List[str], List[int]] = 'auto',
feature_name: _LGBM_FeatureNameConfiguration = 'auto',
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto',
dataset_params: Optional[Dict[str, Any]] = None,
free_raw_data: bool = True,
validate_features: bool = False,
......
......@@ -11,7 +11,8 @@ import numpy as np
from . import callback
from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor,
_LGBM_CustomObjectiveFunction, _log_warning)
_LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction,
_LGBM_FeatureNameConfiguration, _log_warning)
from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
__all__ = [
......@@ -40,8 +41,8 @@ def train(
valid_names: Optional[List[str]] = None,
feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
init_model: Optional[Union[str, Path, Booster]] = None,
feature_name: Union[List[str], str] = 'auto',
categorical_feature: Union[List[str], List[int], str] = 'auto',
feature_name: _LGBM_FeatureNameConfiguration = 'auto',
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto',
keep_training_booster: bool = False,
callbacks: Optional[List[Callable]] = None
) -> Booster:
......@@ -523,8 +524,8 @@ def cv(
metrics: Optional[Union[str, List[str]]] = None,
feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
init_model: Optional[Union[str, Path, Booster]] = None,
feature_name: Union[str, List[str]] = 'auto',
categorical_feature: Union[str, List[str], List[int]] = 'auto',
feature_name: _LGBM_FeatureNameConfiguration = 'auto',
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto',
fpreproc: Optional[_LGBM_PreprocFunction] = None,
seed: int = 0,
callbacks: Optional[List[Callable]] = None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment