Unverified Commit eb34762e authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints on raw data passed to Dataset and Booster (#5752)

parent 356a7806
...@@ -35,6 +35,17 @@ _LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]] ...@@ -35,6 +35,17 @@ _LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool] _LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
_LGBM_CategoricalFeatureConfiguration = Union[List[str], List[int], str] _LGBM_CategoricalFeatureConfiguration = Union[List[str], List[int], str]
_LGBM_FeatureNameConfiguration = Union[List[str], str] _LGBM_FeatureNameConfiguration = Union[List[str], str]
_LGBM_TrainDataType = Union[
str,
Path,
np.ndarray,
pd_DataFrame,
dt_DataTable,
scipy.sparse.spmatrix,
"Sequence",
List["Sequence"],
List[np.ndarray]
]
_LGBM_LabelType = Union[ _LGBM_LabelType = Union[
list, list,
np.ndarray, np.ndarray,
...@@ -1400,7 +1411,7 @@ class Dataset: ...@@ -1400,7 +1411,7 @@ class Dataset:
def __init__( def __init__(
self, self,
data, data: _LGBM_TrainDataType,
label: Optional[_LGBM_LabelType] = None, label: Optional[_LGBM_LabelType] = None,
reference: Optional["Dataset"] = None, reference: Optional["Dataset"] = None,
weight=None, weight=None,
...@@ -1708,7 +1719,7 @@ class Dataset: ...@@ -1708,7 +1719,7 @@ class Dataset:
def _lazy_init( def _lazy_init(
self, self,
data, data: Optional[_LGBM_TrainDataType],
label: Optional[_LGBM_LabelType] = None, label: Optional[_LGBM_LabelType] = None,
reference: Optional["Dataset"] = None, reference: Optional["Dataset"] = None,
weight=None, weight=None,
...@@ -2146,7 +2157,7 @@ class Dataset: ...@@ -2146,7 +2157,7 @@ class Dataset:
def create_valid( def create_valid(
self, self,
data, data: _LGBM_TrainDataType,
label: Optional[_LGBM_LabelType] = None, label: Optional[_LGBM_LabelType] = None,
weight=None, weight=None,
group=None, group=None,
...@@ -2673,7 +2684,7 @@ class Dataset: ...@@ -2673,7 +2684,7 @@ class Dataset:
self.init_score = self.get_field('init_score') self.init_score = self.get_field('init_score')
return self.init_score return self.init_score
def get_data(self): def get_data(self) -> Optional[_LGBM_TrainDataType]:
"""Get the raw data of the Dataset. """Get the raw data of the Dataset.
Returns Returns
...@@ -4016,7 +4027,7 @@ class Booster: ...@@ -4016,7 +4027,7 @@ class Booster:
def refit( def refit(
self, self,
data, data: _LGBM_TrainDataType,
label, label,
decay_rate: float = 0.9, decay_rate: float = 0.9,
reference: Optional[Dataset] = None, reference: Optional[Dataset] = None,
...@@ -4034,7 +4045,7 @@ class Booster: ...@@ -4034,7 +4045,7 @@ class Booster:
Parameters Parameters
---------- ----------
data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse data : str, pathlib.Path, numpy array, pandas DataFrame, H2O DataTable's Frame, scipy.sparse, Sequence, list of Sequence or list of numpy array
Data source for refit. Data source for refit.
If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM). If str or pathlib.Path, it represents the path to a text file (CSV, TSV, or LibSVM).
label : list, numpy 1-D array or pandas Series / one-column DataFrame label : list, numpy 1-D array or pandas Series / one-column DataFrame
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment