Unverified Commit 560c884d authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add some type hints on Booster (#5302)

parent 70654048
...@@ -2558,7 +2558,13 @@ class Dataset: ...@@ -2558,7 +2558,13 @@ class Dataset:
class Booster: class Booster:
"""Booster in LightGBM.""" """Booster in LightGBM."""
def __init__(self, params=None, train_set=None, model_file=None, model_str=None): def __init__(
self,
params: Optional[Dict[str, Any]] = None,
train_set: Optional[Dataset] = None,
model_file: Optional[Union[str, Path]] = None,
model_str: Optional[str] = None
):
"""Initialize the Booster. """Initialize the Booster.
Parameters Parameters
...@@ -2670,7 +2676,7 @@ class Booster: ...@@ -2670,7 +2676,7 @@ class Booster:
'to create Booster instance') 'to create Booster instance')
self.params = params self.params = params
def __del__(self): def __del__(self) -> None:
try: try:
if self.network: if self.network:
self.free_network() self.free_network()
...@@ -2682,10 +2688,10 @@ class Booster: ...@@ -2682,10 +2688,10 @@ class Booster:
except AttributeError: except AttributeError:
pass pass
def __copy__(self): def __copy__(self) -> "Booster":
return self.__deepcopy__(None) return self.__deepcopy__(None)
def __deepcopy__(self, _): def __deepcopy__(self, _) -> "Booster":
model_str = self.model_to_string(num_iteration=-1) model_str = self.model_to_string(num_iteration=-1)
booster = Booster(model_str=model_str) booster = Booster(model_str=model_str)
return booster return booster
...@@ -2711,7 +2717,7 @@ class Booster: ...@@ -2711,7 +2717,7 @@ class Booster:
state['handle'] = handle state['handle'] = handle
self.__dict__.update(state) self.__dict__.update(state)
def free_dataset(self): def free_dataset(self) -> "Booster":
"""Free Booster's Datasets. """Free Booster's Datasets.
Returns Returns
...@@ -2724,7 +2730,7 @@ class Booster: ...@@ -2724,7 +2730,7 @@ class Booster:
self.__num_dataset = 0 self.__num_dataset = 0
return self return self
def _free_buffer(self): def _free_buffer(self) -> "Booster":
self.__inner_predict_buffer = [] self.__inner_predict_buffer = []
self.__is_predicted_cur_iter = [] self.__is_predicted_cur_iter = []
return self return self
...@@ -2763,7 +2769,7 @@ class Booster: ...@@ -2763,7 +2769,7 @@ class Booster:
self.network = True self.network = True
return self return self
def free_network(self): def free_network(self) -> "Booster":
"""Free Booster's network. """Free Booster's network.
Returns Returns
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment