Unverified commit 01774bb9, authored by James Lamb, committed by GitHub
Browse files

[python-package] add more type hints on Dataset (#5431)

parent 78f95e41
...@@ -619,7 +619,10 @@ def _dump_pandas_categorical(pandas_categorical, file_name=None): ...@@ -619,7 +619,10 @@ def _dump_pandas_categorical(pandas_categorical, file_name=None):
return pandas_str return pandas_str
def _load_pandas_categorical(file_name=None, model_str=None): def _load_pandas_categorical(
file_name: Optional[Union[str, Path]] = None,
model_str: Optional[str] = None
) -> Optional[str]:
pandas_key = 'pandas_categorical:' pandas_key = 'pandas_categorical:'
offset = -len(pandas_key) offset = -len(pandas_key)
if file_name is not None: if file_name is not None:
...@@ -1879,7 +1882,15 @@ class Dataset: ...@@ -1879,7 +1882,15 @@ class Dataset:
self.feature_name = self.get_feature_name() self.feature_name = self.get_feature_name()
return self return self
def create_valid(self, data, label=None, weight=None, group=None, init_score=None, params=None): def create_valid(
self,
data,
label=None,
weight=None,
group=None,
init_score=None,
params: Optional[Dict[str, Any]] = None
) -> "Dataset":
"""Create validation data align with current Dataset. """Create validation data align with current Dataset.
Parameters Parameters
...@@ -1966,7 +1977,7 @@ class Dataset: ...@@ -1966,7 +1977,7 @@ class Dataset:
c_str(str(filename)))) c_str(str(filename))))
return self return self
def _update_params(self, params): def _update_params(self, params: Optional[Dict[str, Any]]) -> "Dataset":
if not params: if not params:
return self return self
params = deepcopy(params) params = deepcopy(params)
...@@ -1999,7 +2010,11 @@ class Dataset: ...@@ -1999,7 +2010,11 @@ class Dataset:
self.params_back_up = None self.params_back_up = None
return self return self
def set_field(self, field_name, data): def set_field(
self,
field_name: str,
data
) -> "Dataset":
"""Set property into the Dataset. """Set property into the Dataset.
Parameters Parameters
...@@ -2135,7 +2150,10 @@ class Dataset: ...@@ -2135,7 +2150,10 @@ class Dataset:
raise LightGBMError("Cannot set categorical feature after freed raw data, " raise LightGBMError("Cannot set categorical feature after freed raw data, "
"set free_raw_data=False when construct Dataset to avoid this.") "set free_raw_data=False when construct Dataset to avoid this.")
def _set_predictor(self, predictor): def _set_predictor(
self,
predictor: Optional[_InnerPredictor]
) -> "Dataset":
"""Set predictor for continued training. """Set predictor for continued training.
It is not recommended for user to call this function. It is not recommended for user to call this function.
...@@ -2156,7 +2174,7 @@ class Dataset: ...@@ -2156,7 +2174,7 @@ class Dataset:
"set free_raw_data=False when construct Dataset to avoid this.") "set free_raw_data=False when construct Dataset to avoid this.")
return self return self
def set_reference(self, reference): def set_reference(self, reference: "Dataset") -> "Dataset":
"""Set reference Dataset. """Set reference Dataset.
Parameters Parameters
...@@ -2207,7 +2225,7 @@ class Dataset: ...@@ -2207,7 +2225,7 @@ class Dataset:
ctypes.c_int(len(feature_name)))) ctypes.c_int(len(feature_name))))
return self return self
def set_label(self, label): def set_label(self, label) -> "Dataset":
"""Set label of Dataset. """Set label of Dataset.
Parameters Parameters
...@@ -2227,7 +2245,7 @@ class Dataset: ...@@ -2227,7 +2245,7 @@ class Dataset:
self.label = self.get_field('label') # original values can be modified at cpp side self.label = self.get_field('label') # original values can be modified at cpp side
return self return self
def set_weight(self, weight): def set_weight(self, weight) -> "Dataset":
"""Set weight of each instance. """Set weight of each instance.
Parameters Parameters
...@@ -2249,7 +2267,7 @@ class Dataset: ...@@ -2249,7 +2267,7 @@ class Dataset:
self.weight = self.get_field('weight') # original values can be modified at cpp side self.weight = self.get_field('weight') # original values can be modified at cpp side
return self return self
def set_init_score(self, init_score): def set_init_score(self, init_score) -> "Dataset":
"""Set init score of Booster to start from. """Set init score of Booster to start from.
Parameters Parameters
...@@ -2268,7 +2286,7 @@ class Dataset: ...@@ -2268,7 +2286,7 @@ class Dataset:
self.init_score = self.get_field('init_score') # original values can be modified at cpp side self.init_score = self.get_field('init_score') # original values can be modified at cpp side
return self return self
def set_group(self, group): def set_group(self, group) -> "Dataset":
"""Set group size of Dataset (used for ranking). """Set group size of Dataset (used for ranking).
Parameters Parameters
...@@ -2330,7 +2348,7 @@ class Dataset: ...@@ -2330,7 +2348,7 @@ class Dataset:
ptr_string_buffers)) ptr_string_buffers))
return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)] return [string_buffers[i].value.decode('utf-8') for i in range(num_feature)]
def get_label(self): def get_label(self) -> Optional[np.ndarray]:
"""Get the label of the Dataset. """Get the label of the Dataset.
Returns Returns
...@@ -2342,7 +2360,7 @@ class Dataset: ...@@ -2342,7 +2360,7 @@ class Dataset:
self.label = self.get_field('label') self.label = self.get_field('label')
return self.label return self.label
def get_weight(self): def get_weight(self) -> Optional[np.ndarray]:
"""Get the weight of the Dataset. """Get the weight of the Dataset.
Returns Returns
...@@ -2354,7 +2372,7 @@ class Dataset: ...@@ -2354,7 +2372,7 @@ class Dataset:
self.weight = self.get_field('weight') self.weight = self.get_field('weight')
return self.weight return self.weight
def get_init_score(self): def get_init_score(self) -> Optional[np.ndarray]:
"""Get the initial score of the Dataset. """Get the initial score of the Dataset.
Returns Returns
...@@ -2473,7 +2491,7 @@ class Dataset: ...@@ -2473,7 +2491,7 @@ class Dataset:
else: else:
raise LightGBMError("Cannot get feature_num_bin before construct dataset") raise LightGBMError("Cannot get feature_num_bin before construct dataset")
def get_ref_chain(self, ref_limit=100): def get_ref_chain(self, ref_limit: int = 100) -> Set["Dataset"]:
"""Get a chain of Dataset objects. """Get a chain of Dataset objects.
Starts with r, then goes to r.reference (if exists), Starts with r, then goes to r.reference (if exists),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment