Unverified Commit 946817a5 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix mypy errors in engine.py (#4839)



* [python-package] fix mypy errors in engine.py

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* allow for stdv

* whitespace
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent de23b562
...@@ -2,10 +2,15 @@ ...@@ -2,10 +2,15 @@
"""Callbacks library.""" """Callbacks library."""
import collections import collections
from functools import partial from functools import partial
from typing import Any, Callable, Dict, List, Union from typing import Any, Callable, Dict, List, Tuple, Union
from .basic import _ConfigAliases, _log_info, _log_warning from .basic import _ConfigAliases, _log_info, _log_warning
_EvalResultTuple = Union[
List[Tuple[str, str, float, bool]],
List[Tuple[str, str, float, bool, float]]
]
def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool: def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
return curr_score > best_score + delta return curr_score > best_score + delta
...@@ -18,15 +23,15 @@ def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool: ...@@ -18,15 +23,15 @@ def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
class EarlyStopException(Exception): class EarlyStopException(Exception):
"""Exception of early stopping.""" """Exception of early stopping."""
def __init__(self, best_iteration: int, best_score: float) -> None: def __init__(self, best_iteration: int, best_score: _EvalResultTuple) -> None:
"""Create early stopping exception. """Create early stopping exception.
Parameters Parameters
---------- ----------
best_iteration : int best_iteration : int
The best iteration stopped. The best iteration stopped.
best_score : float best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
The score of the best iteration. Scores for each metric, on each validation set, as of the best iteration.
""" """
super().__init__() super().__init__()
self.best_iteration = best_iteration self.best_iteration = best_iteration
...@@ -44,7 +49,7 @@ CallbackEnv = collections.namedtuple( ...@@ -44,7 +49,7 @@ CallbackEnv = collections.namedtuple(
"evaluation_result_list"]) "evaluation_result_list"])
def _format_eval_result(value: list, show_stdv: bool = True) -> str: def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
"""Format metric string.""" """Format metric string."""
if len(value) == 4: if len(value) == 4:
return f"{value[0]}'s {value[1]}: {value[2]:g}" return f"{value[0]}'s {value[1]}: {value[2]:g}"
......
...@@ -223,38 +223,38 @@ def train( ...@@ -223,38 +223,38 @@ def train(
name_valid_sets.append(f'valid_{i}') name_valid_sets.append(f'valid_{i}')
# process callbacks # process callbacks
if callbacks is None: if callbacks is None:
callbacks = set() callbacks_set = set()
else: else:
for i, cb in enumerate(callbacks): for i, cb in enumerate(callbacks):
cb.__dict__.setdefault('order', i - len(callbacks)) cb.__dict__.setdefault('order', i - len(callbacks))
callbacks = set(callbacks) callbacks_set = set(callbacks)
# Most of legacy advanced options becomes callbacks # Most of legacy advanced options becomes callbacks
if verbose_eval != "warn": if verbose_eval != "warn":
_log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. " _log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'log_evaluation()' callback via 'callbacks' argument instead.") "Pass 'log_evaluation()' callback via 'callbacks' argument instead.")
else: else:
if callbacks: # assume user has already specified log_evaluation callback if callbacks_set: # assume user has already specified log_evaluation callback
verbose_eval = False verbose_eval = False
else: else:
verbose_eval = True verbose_eval = True
if verbose_eval is True: if verbose_eval is True:
callbacks.add(callback.log_evaluation()) callbacks_set.add(callback.log_evaluation())
elif isinstance(verbose_eval, int): elif isinstance(verbose_eval, int):
callbacks.add(callback.log_evaluation(verbose_eval)) callbacks_set.add(callback.log_evaluation(verbose_eval))
if early_stopping_rounds is not None and early_stopping_rounds > 0: if early_stopping_rounds is not None and early_stopping_rounds > 0:
callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval))) callbacks_set.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))
if evals_result is not None: if evals_result is not None:
_log_warning("'evals_result' argument is deprecated and will be removed in a future release of LightGBM. " _log_warning("'evals_result' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'record_evaluation()' callback via 'callbacks' argument instead.") "Pass 'record_evaluation()' callback via 'callbacks' argument instead.")
callbacks.add(callback.record_evaluation(evals_result)) callbacks_set.add(callback.record_evaluation(evals_result))
callbacks_before_iter = {cb for cb in callbacks if getattr(cb, 'before_iteration', False)} callbacks_before_iter_set = {cb for cb in callbacks_set if getattr(cb, 'before_iteration', False)}
callbacks_after_iter = callbacks - callbacks_before_iter callbacks_after_iter_set = callbacks_set - callbacks_before_iter_set
callbacks_before_iter = sorted(callbacks_before_iter, key=attrgetter('order')) callbacks_before_iter = sorted(callbacks_before_iter_set, key=attrgetter('order'))
callbacks_after_iter = sorted(callbacks_after_iter, key=attrgetter('order')) callbacks_after_iter = sorted(callbacks_after_iter_set, key=attrgetter('order'))
# construct booster # construct booster
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment