callback.py 9.52 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Callbacks library."""
wxchan's avatar
wxchan committed
3
import collections
wxchan's avatar
wxchan committed
4
from operator import gt, lt
5
from typing import Any, Callable, Dict, List, Union
wxchan's avatar
wxchan committed
6

7
from .basic import _ConfigAliases, _log_info, _log_warning
wxchan's avatar
wxchan committed
8

wxchan's avatar
wxchan committed
9

wxchan's avatar
wxchan committed
10
class EarlyStopException(Exception):
    """Raised internally to signal that training should stop early."""

    def __init__(self, best_iteration: int, best_score: float) -> None:
        """Initialize the exception with the best round found so far.

        Parameters
        ----------
        best_iteration : int
            Zero-based index of the best iteration at the moment training stopped.
        best_score : float
            Evaluation score achieved at ``best_iteration``.
        """
        super().__init__()
        # Carry the best round and its score so the training loop can recover
        # them from the caught exception.
        self.best_score = best_score
        self.best_iteration = best_iteration
wxchan's avatar
wxchan committed
26

wxchan's avatar
wxchan committed
27

wxchan's avatar
wxchan committed
28
29
# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    [
        "model",
        "params",
        "iteration",
        "begin_iteration",
        "end_iteration",
        "evaluation_result_list",
    ],
)

wxchan's avatar
wxchan committed
38

39
def _format_eval_result(value: list, show_stdv: bool = True) -> str:
    """Format one evaluation tuple as a printable metric string.

    A 4-element tuple is ``(data_name, eval_name, score, is_higher_better)``;
    a 5-element tuple additionally carries the stdv as its last element.
    Raises ValueError for any other length.
    """
    if len(value) == 4:
        data_name, eval_name, score = value[0], value[1], value[2]
        return f"{data_name}'s {eval_name}: {score:g}"
    if len(value) == 5:
        data_name, eval_name, score, stdv = value[0], value[1], value[2], value[4]
        base = f"{data_name}'s {eval_name}: {score:g}"
        # Append the stdv part only when requested.
        return f"{base} + {stdv:g}" if show_stdv else base
    raise ValueError("Wrong metric value")
wxchan's avatar
wxchan committed
50
51


52
def print_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to print the evaluation results.
    show_stdv : bool, optional (default=True)
        Whether to show stdv (if provided).

    Returns
    -------
    callback : function
        The callback that prints the evaluation results every ``period`` iteration(s).
    """
    def _callback(env: CallbackEnv) -> None:
        # Nothing to print for non-positive periods or when no evaluation ran.
        if period <= 0 or not env.evaluation_result_list:
            return
        # Print only every ``period``-th iteration (1-based numbering).
        if (env.iteration + 1) % period != 0:
            return
        formatted = [_format_eval_result(res, show_stdv) for res in env.evaluation_result_list]
        _log_info(f'[{env.iteration + 1}]\t' + '\t'.join(formatted))
    _callback.order = 10  # type: ignore
    return _callback
wxchan's avatar
wxchan committed
73
74


75
def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

    Returns
    -------
    callback : function
        The callback that records the evaluation history into the passed dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')
    # Start from a clean slate: pre-existing entries are discarded.
    eval_result.clear()

    def _init(env: CallbackEnv) -> None:
        # Build the nested ``{data_name: {eval_name: []}}`` skeleton from the
        # first batch of evaluation results.
        for data_name, eval_name, _, _ in env.evaluation_result_list:
            inner = eval_result.setdefault(data_name, collections.OrderedDict())
            inner.setdefault(eval_name, [])

    def _callback(env: CallbackEnv) -> None:
        # Lazy initialization on the first invocation.
        if not eval_result:
            _init(env)
        # Append this iteration's score for every dataset/metric pair.
        for data_name, eval_name, result, _ in env.evaluation_result_list:
            eval_result[data_name][eval_name].append(result)
    _callback.order = 20  # type: ignore
    return _callback
wxchan's avatar
wxchan committed
106
107


108
def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take in-effect on first iteration.

    Parameters
    ----------
    **kwargs : value should be list or function
        List of parameters for each boosting round
        or a customized function that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If function func, parameter = func(current_round).

    Returns
    -------
    callback : function
        The callback that resets the parameter after the first iteration.
    """
    def _callback(env: CallbackEnv) -> None:
        # Round index relative to the start of this training run.
        round_idx = env.iteration - env.begin_iteration
        updates = {}
        for name, spec in kwargs.items():
            if isinstance(spec, list):
                # A list must supply exactly one value per boosting round.
                if len(spec) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {name!r} has to equal to 'num_boost_round'.")
                candidate = spec[round_idx]
            else:
                # Otherwise treat the value as a callable of the round index.
                candidate = spec(round_idx)
            # Only reset parameters whose value actually changed.
            if candidate != env.params.get(name, None):
                updates[name] = candidate
        if updates:
            env.model.reset_parameter(updates)
            env.params.update(updates)
    # Run before each iteration so the new values affect the upcoming round.
    _callback.before_iteration = True  # type: ignore
    _callback.order = 10  # type: ignore
    return _callback
wxchan's avatar
wxchan committed
146
147


148
def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True) -> Callable:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score stops improving.
    Validation score needs to improve at least every ``early_stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them. But the training data is ignored anyway.
    To check only the first metric set ``first_metric_only`` to True.

    Parameters
    ----------
    stopping_rounds : int
        The possible number of rounds without the trend occurrence.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to print message with early stopping information.

    Returns
    -------
    callback : function
        The callback that activates early stopping.
    """
    # Mutable closure state shared with the nested functions below.  Plain
    # variables could not be rebound from the inner functions without
    # ``nonlocal``, so single-element lists (``enabled``, ``first_metric``)
    # and parallel per-metric lists (one slot per entry of
    # ``env.evaluation_result_list``) are used instead.
    best_score = []
    best_iter = []
    best_score_list: list = []
    cmp_op = []
    enabled = [True]
    first_metric = ['']

    def _init(env: CallbackEnv) -> None:
        # One-time setup, run on the first call of ``_callback``.
        # Early stopping is not available with dart boosting; detect it via
        # any of the 'boosting' parameter aliases and disable the callback.
        enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                             in _ConfigAliases.get("boosting"))
        if not enabled[0]:
            _log_warning('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if verbose:
            _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret in env.evaluation_result_list:
            best_iter.append(0)
            best_score_list.append(None)
            # eval_ret[3] flags whether a higher score is better: seed the
            # best score with the worst possible value and pick the matching
            # comparison operator (``gt``/``lt``).
            if eval_ret[3]:
                best_score.append(float('-inf'))
                cmp_op.append(gt)
            else:
                best_score.append(float('inf'))
                cmp_op.append(lt)

    def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        # On the last boosting round, stop with the best iteration found so
        # far even though the improvement window was not exhausted.
        if env.iteration == env.end_iteration - 1:
            if verbose:
                best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
                if first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(best_iter[i], best_score_list[i])

    def _callback(env: CallbackEnv) -> None:
        # Lazy init: ``cmp_op`` is only ever populated by ``_init``.
        if not cmp_op:
            _init(env)
        if not enabled[0]:
            return
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                # Snapshot the whole result list at the best round; it is
                # logged and carried by ``EarlyStopException``.
                best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                _final_iteration_check(env, eval_name_splitted, i)
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - best_iter[i] >= stopping_rounds:
                # No improvement for ``stopping_rounds`` rounds: stop training.
                if verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
                    if first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(best_iter[i], best_score_list[i])
            _final_iteration_check(env, eval_name_splitted, i)
    _callback.order = 30  # type: ignore
    return _callback