Unverified Commit 946817a5 authored by James Lamb's avatar James Lamb Committed by GitHub
Browse files

[python-package] fix mypy errors in engine.py (#4839)



* [python-package] fix mypy errors in engine.py

* Apply suggestions from code review
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>

* allow for stdv

* whitespace
Co-authored-by: default avatarNikita Titov <nekit94-08@mail.ru>
parent de23b562
......@@ -2,10 +2,15 @@
"""Callbacks library."""
import collections
from functools import partial
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, List, Tuple, Union
from .basic import _ConfigAliases, _log_info, _log_warning
_EvalResultTuple = Union[
List[Tuple[str, str, float, bool]],
List[Tuple[str, str, float, bool, float]]
]
def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
return curr_score > best_score + delta
......@@ -18,15 +23,15 @@ def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
class EarlyStopException(Exception):
"""Exception of early stopping."""
def __init__(self, best_iteration: int, best_score: float) -> None:
def __init__(self, best_iteration: int, best_score: _EvalResultTuple) -> None:
"""Create early stopping exception.
Parameters
----------
best_iteration : int
The best iteration stopped.
best_score : float
The score of the best iteration.
best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
Scores for each metric, on each validation set, as of the best iteration.
"""
super().__init__()
self.best_iteration = best_iteration
......@@ -44,7 +49,7 @@ CallbackEnv = collections.namedtuple(
"evaluation_result_list"])
def _format_eval_result(value: list, show_stdv: bool = True) -> str:
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
"""Format metric string."""
if len(value) == 4:
return f"{value[0]}'s {value[1]}: {value[2]:g}"
......
......@@ -223,38 +223,38 @@ def train(
name_valid_sets.append(f'valid_{i}')
# process callbacks
if callbacks is None:
callbacks = set()
callbacks_set = set()
else:
for i, cb in enumerate(callbacks):
cb.__dict__.setdefault('order', i - len(callbacks))
callbacks = set(callbacks)
callbacks_set = set(callbacks)
# Most of legacy advanced options becomes callbacks
if verbose_eval != "warn":
_log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'log_evaluation()' callback via 'callbacks' argument instead.")
else:
if callbacks: # assume user has already specified log_evaluation callback
if callbacks_set: # assume user has already specified log_evaluation callback
verbose_eval = False
else:
verbose_eval = True
if verbose_eval is True:
callbacks.add(callback.log_evaluation())
callbacks_set.add(callback.log_evaluation())
elif isinstance(verbose_eval, int):
callbacks.add(callback.log_evaluation(verbose_eval))
callbacks_set.add(callback.log_evaluation(verbose_eval))
if early_stopping_rounds is not None and early_stopping_rounds > 0:
callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))
callbacks_set.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))
if evals_result is not None:
_log_warning("'evals_result' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'record_evaluation()' callback via 'callbacks' argument instead.")
callbacks.add(callback.record_evaluation(evals_result))
callbacks_set.add(callback.record_evaluation(evals_result))
callbacks_before_iter = {cb for cb in callbacks if getattr(cb, 'before_iteration', False)}
callbacks_after_iter = callbacks - callbacks_before_iter
callbacks_before_iter = sorted(callbacks_before_iter, key=attrgetter('order'))
callbacks_after_iter = sorted(callbacks_after_iter, key=attrgetter('order'))
callbacks_before_iter_set = {cb for cb in callbacks_set if getattr(cb, 'before_iteration', False)}
callbacks_after_iter_set = callbacks_set - callbacks_before_iter_set
callbacks_before_iter = sorted(callbacks_before_iter_set, key=attrgetter('order'))
callbacks_after_iter = sorted(callbacks_after_iter_set, key=attrgetter('order'))
# construct booster
try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment