Unverified commit 6ab58edc, authored by James Lamb, committed by GitHub
Browse files

[python-package] add type hints on Dataset (#5353)

parent 25e32e94
...@@ -1227,7 +1227,7 @@ class Dataset: ...@@ -1227,7 +1227,7 @@ class Dataset:
self.version = 0 self.version = 0
self._start_row = 0 # Used when pushing rows one by one. self._start_row = 0 # Used when pushing rows one by one.
def __del__(self): def __del__(self) -> None:
try: try:
self._free_handle() self._free_handle()
except AttributeError: except AttributeError:
...@@ -1412,7 +1412,7 @@ class Dataset: ...@@ -1412,7 +1412,7 @@ class Dataset:
else: else:
return {} return {}
def _free_handle(self): def _free_handle(self) -> "Dataset":
if self.handle is not None: if self.handle is not None:
_safe_call(_LIB.LGBM_DatasetFree(self.handle)) _safe_call(_LIB.LGBM_DatasetFree(self.handle))
self.handle = None self.handle = None
...@@ -1789,7 +1789,7 @@ class Dataset: ...@@ -1789,7 +1789,7 @@ class Dataset:
return False return False
return True return True
def construct(self): def construct(self) -> "Dataset":
"""Lazy init. """Lazy init.
Returns Returns
...@@ -1886,7 +1886,11 @@ class Dataset: ...@@ -1886,7 +1886,11 @@ class Dataset:
ret.pandas_categorical = self.pandas_categorical ret.pandas_categorical = self.pandas_categorical
return ret return ret
def subset(self, used_indices, params=None): def subset(
self,
used_indices: List[int],
params: Optional[Dict[str, Any]] = None
) -> "Dataset":
"""Get subset of current Dataset. """Get subset of current Dataset.
Parameters Parameters
...@@ -1911,7 +1915,7 @@ class Dataset: ...@@ -1911,7 +1915,7 @@ class Dataset:
ret.used_indices = sorted(used_indices) ret.used_indices = sorted(used_indices)
return ret return ret
def save_binary(self, filename): def save_binary(self, filename: Union[str, Path]) -> "Dataset":
"""Save Dataset to a binary file. """Save Dataset to a binary file.
.. note:: .. note::
...@@ -1961,7 +1965,7 @@ class Dataset: ...@@ -1961,7 +1965,7 @@ class Dataset:
raise LightGBMError(_LIB.LGBM_GetLastError().decode('utf-8')) raise LightGBMError(_LIB.LGBM_GetLastError().decode('utf-8'))
return self return self
def _reverse_update_params(self): def _reverse_update_params(self) -> "Dataset":
if self.handle is None: if self.handle is None:
self.params = deepcopy(self.params_back_up) self.params = deepcopy(self.params_back_up)
self.params_back_up = None self.params_back_up = None
...@@ -2026,7 +2030,7 @@ class Dataset: ...@@ -2026,7 +2030,7 @@ class Dataset:
self.version += 1 self.version += 1
return self return self
def get_field(self, field_name): def get_field(self, field_name: str) -> Optional[np.ndarray]:
"""Get property from the Dataset. """Get property from the Dataset.
Parameters Parameters
...@@ -2069,7 +2073,10 @@ class Dataset: ...@@ -2069,7 +2073,10 @@ class Dataset:
arr = arr.reshape((num_data, num_classes), order='F') arr = arr.reshape((num_data, num_classes), order='F')
return arr return arr
def set_categorical_feature(self, categorical_feature): def set_categorical_feature(
self,
categorical_feature: Union[List[int], List[str]]
) -> "Dataset":
"""Set categorical features. """Set categorical features.
Parameters Parameters
...@@ -2147,7 +2154,7 @@ class Dataset: ...@@ -2147,7 +2154,7 @@ class Dataset:
raise LightGBMError("Cannot set reference after freed raw data, " raise LightGBMError("Cannot set reference after freed raw data, "
"set free_raw_data=False when construct Dataset to avoid this.") "set free_raw_data=False when construct Dataset to avoid this.")
def set_feature_name(self, feature_name): def set_feature_name(self, feature_name: List[str]) -> "Dataset":
"""Set feature name. """Set feature name.
Parameters Parameters
...@@ -2256,7 +2263,7 @@ class Dataset: ...@@ -2256,7 +2263,7 @@ class Dataset:
self.set_field('group', group) self.set_field('group', group)
return self return self
def get_feature_name(self): def get_feature_name(self) -> List[str]:
"""Get the names of columns (features) in the Dataset. """Get the names of columns (features) in the Dataset.
Returns Returns
...@@ -2382,7 +2389,7 @@ class Dataset: ...@@ -2382,7 +2389,7 @@ class Dataset:
self.group = np.diff(self.group) self.group = np.diff(self.group)
return self.group return self.group
def num_data(self): def num_data(self) -> int:
"""Get the number of rows in the Dataset. """Get the number of rows in the Dataset.
Returns Returns
...@@ -2398,7 +2405,7 @@ class Dataset: ...@@ -2398,7 +2405,7 @@ class Dataset:
else: else:
raise LightGBMError("Cannot get num_data before construct dataset") raise LightGBMError("Cannot get num_data before construct dataset")
def num_feature(self): def num_feature(self) -> int:
"""Get the number of columns (features) in the Dataset. """Get the number of columns (features) in the Dataset.
Returns Returns
...@@ -2468,7 +2475,7 @@ class Dataset: ...@@ -2468,7 +2475,7 @@ class Dataset:
break break
return ref_chain return ref_chain
def add_features_from(self, other): def add_features_from(self, other: "Dataset") -> "Dataset":
"""Add features from other Dataset to the current Dataset. """Add features from other Dataset to the current Dataset.
Both Datasets must be constructed before calling this method. Both Datasets must be constructed before calling this method.
...@@ -2557,7 +2564,7 @@ class Dataset: ...@@ -2557,7 +2564,7 @@ class Dataset:
self.pandas_categorical = None self.pandas_categorical = None
return self return self
def _dump_text(self, filename): def _dump_text(self, filename: Union[str, Path]) -> "Dataset":
"""Save Dataset to a text file. """Save Dataset to a text file.
This format cannot be loaded back in by LightGBM, but is useful for debugging purposes. This format cannot be loaded back in by LightGBM, but is useful for debugging purposes.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment