# coding: utf-8
"""Callbacks library."""
import collections
from functools import partial
from typing import Any, Callable, Dict, List, Union

from .basic import _ConfigAliases, _log_info, _log_warning


def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
    return curr_score > best_score + delta


def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
    return curr_score < best_score - delta
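
# Both helpers treat ``delta`` as a minimum required improvement: e.g. with
# delta=0.01, a new score of 0.905 does not beat a best score of 0.9 under
# _gt_delta, since 0.905 <= 0.9 + 0.01.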


class EarlyStopException(Exception):
    """Exception of early stopping."""

    def __init__(self, best_iteration: int, best_score: float) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration found before training stopped.
        best_score : float
            The score of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])


def _format_eval_result(value: list, show_stdv: bool = True) -> str:
    """Format metric string."""
    if len(value) == 4:
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    elif len(value) == 5:
        if show_stdv:
            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"
        else:
            return f"{value[0]}'s {value[1]}: {value[2]:g}"
    else:
        raise ValueError("Wrong metric value")
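
# Note: ``value`` is one entry of ``CallbackEnv.evaluation_result_list``: a
# 4-tuple (dataset_name, eval_name, result, is_higher_better) from ``lgb.train``,
# or a 5-tuple whose trailing field is the stdv from ``lgb.cv``.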


def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
    """Create a callback that logs the evaluation results.

    By default, evaluation results are logged to standard output.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation dataset.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : callable
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
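
    .. rubric:: Example

    A minimal usage sketch; ``params``, ``train_set`` and ``valid_set`` are
    assumed to be defined elsewhere:

    .. code-block:: python

        import lightgbm as lgb

        bst = lgb.train(params, train_set, valid_sets=[valid_set],
                        callbacks=[lgb.log_evaluation(period=10)])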
    """
    def _callback(env: CallbackEnv) -> None:
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            _log_info(f'[{env.iteration + 1}]\t{result}')
    _callback.order = 10  # type: ignore
    return _callback


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss',
        this dictionary will have the following structure after the training process finishes:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : callable
        The callback that records the evaluation history into the passed dictionary.
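
    .. rubric:: Example

    A minimal usage sketch; ``params``, ``train_set`` and ``valid_set`` are
    assumed to be defined elsewhere:

    .. code-block:: python

        import lightgbm as lgb

        history = {}
        bst = lgb.train(params, train_set, valid_sets=[valid_set],
                        valid_names=['eval'],
                        callbacks=[lgb.record_evaluation(history)])
        # history now maps 'eval' -> {metric_name: [value per iteration]}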
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')
    eval_result.clear()

    def _init(env: CallbackEnv) -> None:
        for data_name, eval_name, _, _ in env.evaluation_result_list:
            eval_result.setdefault(data_name, collections.OrderedDict())
            eval_result[data_name].setdefault(eval_name, [])

    def _callback(env: CallbackEnv) -> None:
        if not eval_result:
            _init(env)
        for data_name, eval_name, result, _ in env.evaluation_result_list:
            eval_result[data_name][eval_name].append(result)
    _callback.order = 20  # type: ignore
    return _callback


def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter still takes effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter from the current round number
        (e.g. to implement learning-rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : callable
        The callback that resets the parameter after the first iteration.
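
    .. rubric:: Example

    A sketch of learning-rate decay; ``params`` and ``train_set`` are assumed
    to be defined elsewhere:

    .. code-block:: python

        import lightgbm as lgb

        bst = lgb.train(params, train_set, num_boost_round=100,
                        callbacks=[lgb.reset_parameter(
                            learning_rate=lambda i: 0.1 * (0.99 ** i))])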
    """
    def _callback(env: CallbackEnv) -> None:
        new_parameters = {}
        for key, value in kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {key!r} must equal 'num_boost_round'.")
                new_param = value[env.iteration - env.begin_iteration]
            else:
                new_param = value(env.iteration - env.begin_iteration)
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    _callback.before_iteration = True  # type: ignore
    _callback.order = 10  # type: ignore
    return _callback


def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation dataset and one metric.
    If there's more than one, all of them will be checked, but the training data is ignored.
    To check only the first metric, set ``first_metric_only`` to True.
    The index of the iteration with the best performance will be saved in the ``best_iteration`` attribute of a model.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without improvement after which training will be stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log messages with early stopping information.
        By default, messages are logged to standard output.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

    Returns
    -------
    callback : callable
        The callback that activates early stopping.
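
    .. rubric:: Example

    A minimal usage sketch; ``params``, ``train_set`` and ``valid_set`` are
    assumed to be defined elsewhere:

    .. code-block:: python

        import lightgbm as lgb

        bst = lgb.train(params, train_set, valid_sets=[valid_set],
                        callbacks=[lgb.early_stopping(stopping_rounds=5,
                                                      min_delta=1e-4)])
        print(bst.best_iteration)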
    """
    best_score = []
    best_iter = []
    best_score_list: list = []
    cmp_op = []
    enabled = [True]
    first_metric = ['']

    def _init(env: CallbackEnv) -> None:
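        # DART may mutate or drop previously built trees, so a saved "best
        # iteration" cannot be restored reliably; disable early stopping.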
        enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                             in _ConfigAliases.get("boosting"))
        if not enabled[0]:
            _log_warning('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if verbose:
            _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")

        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
        n_datasets = len(env.evaluation_result_list) // n_metrics
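        # Normalize ``min_delta`` into one threshold per entry of
        # ``evaluation_result_list`` (datasets outermost, metrics innermost).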
        if isinstance(min_delta, list):
            if not all(t >= 0 for t in min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(min_delta) == 0:
                if verbose:
                    _log_info('Disabling min_delta for early stopping.')
                deltas = [0.0] * n_datasets * n_metrics
            elif len(min_delta) == 1:
                if verbose:
                    _log_info(f'Using {min_delta[0]} as min_delta for all metrics.')
                deltas = min_delta * n_datasets * n_metrics
            else:
                if len(min_delta) != n_metrics:
                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
                if first_metric_only and verbose:
                    _log_info(f'Using only {min_delta[0]} as early stopping min_delta.')
                deltas = min_delta * n_datasets
        else:
            if min_delta < 0:
                raise ValueError('Early stopping min_delta must be non-negative.')
            if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose:
                _log_info(f'Using {min_delta} as min_delta for all metrics.')
            deltas = [min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            best_iter.append(0)
            best_score_list.append(None)
            if eval_ret[3]:  # greater is better
                best_score.append(float('-inf'))
                cmp_op.append(partial(_gt_delta, delta=delta))
            else:
                best_score.append(float('inf'))
                cmp_op.append(partial(_lt_delta, delta=delta))

    def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
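        # On the last boosting round, report the best result seen so far and
        # stop, even though the early stopping condition was never triggered.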
        if env.iteration == env.end_iteration - 1:
            if verbose:
                best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
                if first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(best_iter[i], best_score_list[i])

    def _callback(env: CallbackEnv) -> None:
        if not cmp_op:
            _init(env)
        if not enabled[0]:
            return
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                _final_iteration_check(env, eval_name_splitted, i)
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
                    if first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(best_iter[i], best_score_list[i])
            _final_iteration_check(env, eval_name_splitted, i)
    _callback.order = 30  # type: ignore
    return _callback