Unverified Commit 55a31bfe authored by Deddy Jobson, committed by GitHub
Browse files

[python-package] Add type hints to the callback file (#4093)



* added type hints; implemented one workaround

* resolving some linting errors

* Added doc strings

* fixed more linting errors

* Made documentation more imperative.

* removed one type hint

* more specific type hinting
Co-authored-by: default avatarJames Lamb <jaylamb20@gmail.com>

* added import

* Apply suggestions from code review
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* made a class and function private

* Apply suggestions from code review

Make the documentation clearer.
Co-authored-by: James Lamb <jaylamb20@gmail.com>

* linting error fix

* more linting errors

* removing the decorator

* ignore mypy function attribute errors

* fix lints
Co-authored-by: James Lamb <jaylamb20@gmail.com>
parent 536946e3
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
"""Callbacks library.""" """Callbacks library."""
import collections import collections
from operator import gt, lt from operator import gt, lt
from typing import Any, Callable, Dict, List, Union
from .basic import _ConfigAliases, _log_info, _log_warning from .basic import _ConfigAliases, _log_info, _log_warning
...@@ -9,7 +10,7 @@ from .basic import _ConfigAliases, _log_info, _log_warning ...@@ -9,7 +10,7 @@ from .basic import _ConfigAliases, _log_info, _log_warning
class EarlyStopException(Exception): class EarlyStopException(Exception):
"""Exception of early stopping.""" """Exception of early stopping."""
def __init__(self, best_iteration, best_score): def __init__(self, best_iteration: int, best_score: float) -> None:
"""Create early stopping exception. """Create early stopping exception.
Parameters Parameters
...@@ -35,7 +36,7 @@ CallbackEnv = collections.namedtuple( ...@@ -35,7 +36,7 @@ CallbackEnv = collections.namedtuple(
"evaluation_result_list"]) "evaluation_result_list"])
def _format_eval_result(value, show_stdv=True): def _format_eval_result(value: list, show_stdv: bool = True) -> str:
"""Format metric string.""" """Format metric string."""
if len(value) == 4: if len(value) == 4:
return '%s\'s %s: %g' % (value[0], value[1], value[2]) return '%s\'s %s: %g' % (value[0], value[1], value[2])
...@@ -48,7 +49,7 @@ def _format_eval_result(value, show_stdv=True): ...@@ -48,7 +49,7 @@ def _format_eval_result(value, show_stdv=True):
raise ValueError("Wrong metric value") raise ValueError("Wrong metric value")
def print_evaluation(period=1, show_stdv=True): def print_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
"""Create a callback that prints the evaluation results. """Create a callback that prints the evaluation results.
Parameters Parameters
...@@ -63,15 +64,15 @@ def print_evaluation(period=1, show_stdv=True): ...@@ -63,15 +64,15 @@ def print_evaluation(period=1, show_stdv=True):
callback : function callback : function
The callback that prints the evaluation results every ``period`` iteration(s). The callback that prints the evaluation results every ``period`` iteration(s).
""" """
def _callback(env): def _callback(env: CallbackEnv) -> None:
if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0: if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list]) result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
_log_info('[%d]\t%s' % (env.iteration + 1, result)) _log_info('[%d]\t%s' % (env.iteration + 1, result))
_callback.order = 10 _callback.order = 10 # type: ignore
return _callback return _callback
def record_evaluation(eval_result): def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
"""Create a callback that records the evaluation history into ``eval_result``. """Create a callback that records the evaluation history into ``eval_result``.
Parameters Parameters
...@@ -88,21 +89,21 @@ def record_evaluation(eval_result): ...@@ -88,21 +89,21 @@ def record_evaluation(eval_result):
raise TypeError('eval_result should be a dictionary') raise TypeError('eval_result should be a dictionary')
eval_result.clear() eval_result.clear()
def _init(env): def _init(env: CallbackEnv) -> None:
for data_name, eval_name, _, _ in env.evaluation_result_list: for data_name, eval_name, _, _ in env.evaluation_result_list:
eval_result.setdefault(data_name, collections.OrderedDict()) eval_result.setdefault(data_name, collections.OrderedDict())
eval_result[data_name].setdefault(eval_name, []) eval_result[data_name].setdefault(eval_name, [])
def _callback(env): def _callback(env: CallbackEnv) -> None:
if not eval_result: if not eval_result:
_init(env) _init(env)
for data_name, eval_name, result, _ in env.evaluation_result_list: for data_name, eval_name, result, _ in env.evaluation_result_list:
eval_result[data_name][eval_name].append(result) eval_result[data_name][eval_name].append(result)
_callback.order = 20 _callback.order = 20 # type: ignore
return _callback return _callback
def reset_parameter(**kwargs): def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
"""Create a callback that resets the parameter after the first iteration. """Create a callback that resets the parameter after the first iteration.
.. note:: .. note::
...@@ -123,7 +124,7 @@ def reset_parameter(**kwargs): ...@@ -123,7 +124,7 @@ def reset_parameter(**kwargs):
callback : function callback : function
The callback that resets the parameter after the first iteration. The callback that resets the parameter after the first iteration.
""" """
def _callback(env): def _callback(env: CallbackEnv) -> None:
new_parameters = {} new_parameters = {}
for key, value in kwargs.items(): for key, value in kwargs.items():
if isinstance(value, list): if isinstance(value, list):
...@@ -138,12 +139,12 @@ def reset_parameter(**kwargs): ...@@ -138,12 +139,12 @@ def reset_parameter(**kwargs):
if new_parameters: if new_parameters:
env.model.reset_parameter(new_parameters) env.model.reset_parameter(new_parameters)
env.params.update(new_parameters) env.params.update(new_parameters)
_callback.before_iteration = True _callback.before_iteration = True # type: ignore
_callback.order = 10 _callback.order = 10 # type: ignore
return _callback return _callback
def early_stopping(stopping_rounds, first_metric_only=False, verbose=True): def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True) -> Callable:
"""Create a callback that activates early stopping. """Create a callback that activates early stopping.
Activates early stopping. Activates early stopping.
...@@ -170,12 +171,12 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True): ...@@ -170,12 +171,12 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
""" """
best_score = [] best_score = []
best_iter = [] best_iter = []
best_score_list = [] best_score_list: list = []
cmp_op = [] cmp_op = []
enabled = [True] enabled = [True]
first_metric = [''] first_metric = ['']
def _init(env): def _init(env: CallbackEnv) -> None:
enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
in _ConfigAliases.get("boosting")) in _ConfigAliases.get("boosting"))
if not enabled[0]: if not enabled[0]:
...@@ -200,7 +201,7 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True): ...@@ -200,7 +201,7 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
best_score.append(float('inf')) best_score.append(float('inf'))
cmp_op.append(lt) cmp_op.append(lt)
def _final_iteration_check(env, eval_name_splitted, i): def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
if env.iteration == env.end_iteration - 1: if env.iteration == env.end_iteration - 1:
if verbose: if verbose:
_log_info('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % ( _log_info('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
...@@ -209,7 +210,7 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True): ...@@ -209,7 +210,7 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
_log_info("Evaluated only: {}".format(eval_name_splitted[-1])) _log_info("Evaluated only: {}".format(eval_name_splitted[-1]))
raise EarlyStopException(best_iter[i], best_score_list[i]) raise EarlyStopException(best_iter[i], best_score_list[i])
def _callback(env): def _callback(env: CallbackEnv) -> None:
if not cmp_op: if not cmp_op:
_init(env) _init(env)
if not enabled[0]: if not enabled[0]:
...@@ -236,5 +237,5 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True): ...@@ -236,5 +237,5 @@ def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
_log_info("Evaluated only: {}".format(eval_name_splitted[-1])) _log_info("Evaluated only: {}".format(eval_name_splitted[-1]))
raise EarlyStopException(best_iter[i], best_score_list[i]) raise EarlyStopException(best_iter[i], best_score_list[i])
_final_iteration_check(env, eval_name_splitted, i) _final_iteration_check(env, eval_name_splitted, i)
_callback.order = 30 _callback.order = 30 # type: ignore
return _callback return _callback
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment