# coding: utf-8
"""Callbacks library."""
import collections
from operator import gt, lt
from typing import Any, Callable, Dict, List, Union

from .basic import _ConfigAliases, _log_info, _log_warning


class EarlyStopException(Exception):
    """Exception of early stopping."""

    def __init__(self, best_iteration: int, best_score: float) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
        best_score : float
            The score of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score

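# Callbacks raise EarlyStopException to end training early; the training loop
# (``train()``/``cv()`` in ``engine.py``) catches it and uses
# ``best_iteration`` and ``best_score`` to finalize the returned booster.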

# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])

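# A custom callback is any callable accepting a single ``CallbackEnv``. An
# optional integer ``order`` attribute determines the order callbacks run in,
# and a truthy ``before_iteration`` attribute makes a callback run before
# (rather than after) each boosting iteration. A minimal sketch (the name is
# illustrative, not part of the public API):
#
#     def _log_progress(env: CallbackEnv) -> None:
#         _log_info(f"finished iteration {env.iteration + 1} of {env.end_iteration}")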

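# Each entry of ``evaluation_result_list`` is a tuple
# ``(dataset_name, eval_name, value, is_higher_better)``; cv results carry a
# fifth ``stdv`` element, which is why lengths 4 and 5 are handled below.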
def _format_eval_result(value: list, show_stdv: bool = True) -> str:
    """Format metric string."""
    if len(value) == 4:
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    elif len(value) == 5:
        if show_stdv:
            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"
        else:
            return f"{value[0]}'s {value[1]}: {value[2]:g}"
    else:
        raise ValueError("Wrong metric value")


def print_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
    """Create a callback that logs the evaluation results.

    By default, results are written to standard output.
    Use the ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation dataset.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : callable
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
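
    .. rubric:: Example

    A minimal usage sketch (``train_set`` and ``valid_set`` are assumed to be
    ``lightgbm.Dataset`` objects built elsewhere):

    .. code-block::

        import lightgbm as lgb

        booster = lgb.train(
            {"objective": "regression"},
            train_set,
            valid_sets=[valid_set],
            callbacks=[lgb.print_evaluation(period=5)],
        )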
    """
    def _callback(env: CallbackEnv) -> None:
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            _log_info(f'[{env.iteration + 1}]\t{result}')
    _callback.order = 10  # type: ignore
    return _callback


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss'
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : callable
        The callback that records the evaluation history into the passed dictionary.
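
    .. rubric:: Example

    A minimal usage sketch (``train_set`` and ``valid_set`` are assumed to be
    ``lightgbm.Dataset`` objects built elsewhere):

    .. code-block::

        import lightgbm as lgb

        history = {}
        booster = lgb.train(
            {"objective": "regression"},
            train_set,
            valid_sets=[train_set, valid_set],
            valid_names=["train", "eval"],
            callbacks=[lgb.record_evaluation(history)],
        )
        # ``history`` now has the nested structure shown above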
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')
    eval_result.clear()

    def _init(env: CallbackEnv) -> None:
        for data_name, eval_name, _, _ in env.evaluation_result_list:
            eval_result.setdefault(data_name, collections.OrderedDict())
            eval_result[data_name].setdefault(eval_name, [])

    def _callback(env: CallbackEnv) -> None:
        if not eval_result:
            _init(env)
        for data_name, eval_name, result, _ in env.evaluation_result_list:
            eval_result[data_name][eval_name].append(result)
    _callback.order = 20  # type: ignore
    return _callback


def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter value still takes effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter in terms of
        the current round number (e.g. to implement learning rate decay).
        If a list ``lst``, parameter = ``lst[current_round]``.
        If a callable ``func``, parameter = ``func(current_round)``.

    Returns
    -------
    callback : callable
        The callback that resets the parameter after the first iteration.
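
    .. rubric:: Example

    A minimal sketch using a callable for learning-rate decay (the schedule is
    illustrative only; ``train_set`` is assumed built elsewhere):

    .. code-block::

        import lightgbm as lgb

        booster = lgb.train(
            {"objective": "regression", "learning_rate": 0.1},
            train_set,
            num_boost_round=100,
            callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 * (0.99 ** i))],
        )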
    """
    def _callback(env: CallbackEnv) -> None:
        new_parameters = {}
        for key, value in kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.")
                new_param = value[env.iteration - env.begin_iteration]
            else:
                new_param = value(env.iteration - env.begin_iteration)
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    _callback.before_iteration = True  # type: ignore
    _callback.order = 10  # type: ignore
    return _callback


def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True) -> Callable:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score stops improving.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation dataset and one metric.
    If there's more than one, all of them will be checked, but the training data is ignored anyway.
    To check only the first metric, set ``first_metric_only`` to True.
    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without improvement after which training will be stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log messages with early stopping information.
        By default, messages are written to standard output.
        Use the ``register_logger()`` function to register a custom logger.

    Returns
    -------
    callback : callable
        The callback that activates early stopping.
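
    .. rubric:: Example

    A minimal usage sketch (``train_set`` and ``valid_set`` are assumed to be
    ``lightgbm.Dataset`` objects built elsewhere):

    .. code-block::

        import lightgbm as lgb

        booster = lgb.train(
            {"objective": "regression", "metric": "l1"},
            train_set,
            num_boost_round=1000,
            valid_sets=[valid_set],
            callbacks=[lgb.early_stopping(stopping_rounds=50)],
        )
        print(booster.best_iteration)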
    """
    best_score = []
    best_iter = []
    best_score_list: list = []
    cmp_op = []
    enabled = [True]
    first_metric = ['']

    def _init(env: CallbackEnv) -> None:
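        # "dart" drops and re-weights earlier trees during training, so a past
        # "best" iteration cannot be restored; disable early stopping for it.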
        enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                             in _ConfigAliases.get("boosting"))
        if not enabled[0]:
            _log_warning('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if verbose:
            _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret in env.evaluation_result_list:
            best_iter.append(0)
            best_score_list.append(None)
            if eval_ret[3]:
                best_score.append(float('-inf'))
                cmp_op.append(gt)
            else:
                best_score.append(float('inf'))
                cmp_op.append(lt)

    def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        if env.iteration == env.end_iteration - 1:
            if verbose:
                best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
                if first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(best_iter[i], best_score_list[i])

    def _callback(env: CallbackEnv) -> None:
        if not cmp_op:
            _init(env)
        if not enabled[0]:
            return
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                _final_iteration_check(env, eval_name_splitted, i)
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
                    if first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(best_iter[i], best_score_list[i])
            _final_iteration_check(env, eval_name_splitted, i)
    _callback.order = 30  # type: ignore
    return _callback