Unverified Commit 273c9b04 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] add type hints on Booster.update() (#5359)

* [python-package] add type hints on Booster.update()

* add __boost()
parent a5b2e049
...@@ -2585,6 +2585,12 @@ class Dataset: ...@@ -2585,6 +2585,12 @@ class Dataset:
return self return self
_LGBM_CustomObjectiveFunction = Callable[
[np.ndarray, Dataset],
Tuple[np.ndarray, np.ndarray]
]
class Booster: class Booster:
"""Booster in LightGBM.""" """Booster in LightGBM."""
...@@ -3014,7 +3020,11 @@ class Booster: ...@@ -3014,7 +3020,11 @@ class Booster:
self.params.update(params) self.params.update(params)
return self return self
def update(self, train_set=None, fobj=None): def update(
self,
train_set: Optional[Dataset] = None,
fobj: Optional[_LGBM_CustomObjectiveFunction] = None
) -> bool:
"""Update Booster for one iteration. """Update Booster for one iteration.
Parameters Parameters
...@@ -3081,7 +3091,11 @@ class Booster: ...@@ -3081,7 +3091,11 @@ class Booster:
grad, hess = fobj(self.__inner_predict(0), self.train_set) grad, hess = fobj(self.__inner_predict(0), self.train_set)
return self.__boost(grad, hess) return self.__boost(grad, hess)
def __boost(self, grad, hess): def __boost(
self,
grad: np.ndarray,
hess: np.ndarray
) -> bool:
"""Boost Booster for one iteration with customized gradient statistics. """Boost Booster for one iteration with customized gradient statistics.
.. note:: .. note::
......
...@@ -9,7 +9,8 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union ...@@ -9,7 +9,8 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
from . import callback from . import callback
from .basic import Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor, _log_warning from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor,
_LGBM_CustomObjectiveFunction, _log_warning)
from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold
_LGBM_CustomMetricFunction = Callable[ _LGBM_CustomMetricFunction = Callable[
...@@ -131,7 +132,7 @@ def train( ...@@ -131,7 +132,7 @@ def train(
params=params, params=params,
default_value=None default_value=None
) )
fobj = None fobj: Optional[_LGBM_CustomObjectiveFunction] = None
if callable(params["objective"]): if callable(params["objective"]):
fobj = params["objective"] fobj = params["objective"]
params["objective"] = 'none' params["objective"] = 'none'
...@@ -523,7 +524,7 @@ def cv( ...@@ -523,7 +524,7 @@ def cv(
params=params, params=params,
default_value=None default_value=None
) )
fobj = None fobj: Optional[_LGBM_CustomObjectiveFunction] = None
if callable(params["objective"]): if callable(params["objective"]):
fobj = params["objective"] fobj = params["objective"]
params["objective"] = 'none' params["objective"] = 'none'
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment