Unverified Commit c8712a9e authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints on Dataset constructors (#5458)

* [python-package] add type hints on Dataset constructors

* fix __init_from_list_np2d() hint

* add return type

* define a DatasetHandle type
parent d0ea321c
...@@ -20,6 +20,7 @@ import scipy.sparse ...@@ -20,6 +20,7 @@ import scipy.sparse
from .compat import PANDAS_INSTALLED, concat, dt_DataTable, pd_CategoricalDtype, pd_DataFrame, pd_Series from .compat import PANDAS_INSTALLED, concat, dt_DataTable, pd_CategoricalDtype, pd_DataFrame, pd_Series
from .libpath import find_lib_path from .libpath import find_lib_path
_DatasetHandle = ctypes.c_void_p
_LGBM_EvalFunctionResultType = Tuple[str, float, bool] _LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool] _LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
...@@ -1196,10 +1197,19 @@ class _InnerPredictor: ...@@ -1196,10 +1197,19 @@ class _InnerPredictor:
class Dataset: class Dataset:
"""Dataset in LightGBM.""" """Dataset in LightGBM."""
def __init__(self, data, label=None, reference=None, def __init__(
weight=None, group=None, init_score=None, self,
feature_name='auto', categorical_feature='auto', params=None, data,
free_raw_data=True): label=None,
reference: Optional["Dataset"] = None,
weight=None,
group=None,
init_score=None,
feature_name='auto',
categorical_feature='auto',
params: Optional[Dict[str, Any]] = None,
free_raw_data: bool = True
):
"""Initialize Dataset. """Initialize Dataset.
Parameters Parameters
...@@ -1488,9 +1498,19 @@ class Dataset: ...@@ -1488,9 +1498,19 @@ class Dataset:
return self return self
self.set_init_score(init_score) self.set_init_score(init_score)
def _lazy_init(self, data, label=None, reference=None, def _lazy_init(
weight=None, group=None, init_score=None, predictor=None, self,
feature_name='auto', categorical_feature='auto', params=None): data,
label=None,
reference: Optional["Dataset"] = None,
weight=None,
group=None,
init_score=None,
predictor=None,
feature_name='auto',
categorical_feature='auto',
params: Optional[Dict[str, Any]] = None
) -> "Dataset":
if data is None: if data is None:
self.handle = None self.handle = None
return self return self
...@@ -1635,7 +1655,11 @@ class Dataset: ...@@ -1635,7 +1655,11 @@ class Dataset:
return filtered, filtered_idx return filtered, filtered_idx
def __init_from_seqs(self, seqs: List[Sequence], ref_dataset: Optional['Dataset'] = None): def __init_from_seqs(
self,
seqs: List[Sequence],
ref_dataset: Optional["Dataset"] = None
) -> "Dataset":
""" """
Initialize data from list of Sequence objects. Initialize data from list of Sequence objects.
...@@ -1664,7 +1688,12 @@ class Dataset: ...@@ -1664,7 +1688,12 @@ class Dataset:
self._push_rows(seq[start:end]) self._push_rows(seq[start:end])
return self return self
def __init_from_np2d(self, mat, params_str, ref_dataset): def __init_from_np2d(
self,
mat: np.ndarray,
params_str: str,
ref_dataset: Optional[_DatasetHandle]
) -> "Dataset":
"""Initialize data from a 2-D numpy matrix.""" """Initialize data from a 2-D numpy matrix."""
if len(mat.shape) != 2: if len(mat.shape) != 2:
raise ValueError('Input numpy.ndarray must be 2 dimensional') raise ValueError('Input numpy.ndarray must be 2 dimensional')
...@@ -1687,7 +1716,12 @@ class Dataset: ...@@ -1687,7 +1716,12 @@ class Dataset:
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
return self return self
def __init_from_list_np2d(self, mats, params_str, ref_dataset): def __init_from_list_np2d(
self,
mats: List[np.ndarray],
params_str: str,
ref_dataset: Optional[_DatasetHandle]
) -> "Dataset":
"""Initialize data from a list of 2-D numpy matrices.""" """Initialize data from a list of 2-D numpy matrices."""
ncol = mats[0].shape[1] ncol = mats[0].shape[1]
nrow = np.empty((len(mats),), np.int32) nrow = np.empty((len(mats),), np.int32)
...@@ -1733,7 +1767,12 @@ class Dataset: ...@@ -1733,7 +1767,12 @@ class Dataset:
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
return self return self
def __init_from_csr(self, csr, params_str, ref_dataset): def __init_from_csr(
self,
csr: scipy.sparse.csr_matrix,
params_str: str,
ref_dataset: Optional[_DatasetHandle]
) -> "Dataset":
"""Initialize data from a CSR matrix.""" """Initialize data from a CSR matrix."""
if len(csr.indices) != len(csr.data): if len(csr.indices) != len(csr.data):
raise ValueError(f'Length mismatch: {len(csr.indices)} vs {len(csr.data)}') raise ValueError(f'Length mismatch: {len(csr.indices)} vs {len(csr.data)}')
...@@ -1759,7 +1798,12 @@ class Dataset: ...@@ -1759,7 +1798,12 @@ class Dataset:
ctypes.byref(self.handle))) ctypes.byref(self.handle)))
return self return self
def __init_from_csc(self, csc, params_str, ref_dataset): def __init_from_csc(
self,
csc: scipy.sparse.csc_matrix,
params_str: str,
ref_dataset: Optional[_DatasetHandle]
) -> "Dataset":
"""Initialize data from a CSC matrix.""" """Initialize data from a CSC matrix."""
if len(csc.indices) != len(csc.data): if len(csc.indices) != len(csc.data):
raise ValueError(f'Length mismatch: {len(csc.indices)} vs {len(csc.data)}') raise ValueError(f'Length mismatch: {len(csc.indices)} vs {len(csc.data)}')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment