callback.py 13.4 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Callbacks library."""
wxchan's avatar
wxchan committed
3
import collections
4
from functools import partial
5
from typing import Any, Callable, Dict, List, Tuple, Union
wxchan's avatar
wxchan committed
6

7
from .basic import _ConfigAliases, _log_info, _log_warning
wxchan's avatar
wxchan committed
8

9
10
11
12
13
_EvalResultTuple = Union[
    List[Tuple[str, str, float, bool]],
    List[Tuple[str, str, float, bool, float]]
]

wxchan's avatar
wxchan committed
14

15
16
17
18
19
20
21
22
def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
    return curr_score > best_score + delta


def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
    return curr_score < best_score - delta


wxchan's avatar
wxchan committed
23
class EarlyStopException(Exception):
    """Exception raised to signal that early stopping was triggered."""

    def __init__(self, best_iteration: int, best_score: _EvalResultTuple) -> None:
        """Build the exception carrying the best state observed so far.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        super().__init__()
        # Stash the best state so the training loop catching this can recover it.
        self.best_iteration, self.best_score = best_iteration, best_score
wxchan's avatar
wxchan committed
39

wxchan's avatar
wxchan committed
40

wxchan's avatar
wxchan committed
41
42
# Callback environment used by callbacks: an immutable snapshot of the
# training state handed to every callback at each iteration.
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    [
        "model",                   # the model being trained
        "params",                  # current parameter dict
        "iteration",               # current iteration index
        "begin_iteration",         # first iteration of this training run
        "end_iteration",           # one past the last iteration
        "evaluation_result_list",  # list of evaluation-result tuples
    ],
)

wxchan's avatar
wxchan committed
51

52
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
    """Format metric string."""
    # Entries are either 4-tuples (regular) or 5-tuples (with stdv at index 4).
    if len(value) not in (4, 5):
        raise ValueError("Wrong metric value")
    if len(value) == 5 and show_stdv:
        return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"
    return f"{value[0]}'s {value[1]}: {value[2]:g}"
wxchan's avatar
wxchan committed
63
64


65
66
67
def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
    """Create a callback that logs the evaluation results.

    By default, standard output resource is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation data.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : callable
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    def _callback(env: CallbackEnv) -> None:
        # Skip when logging is disabled or no evaluation results are present.
        if period <= 0 or not env.evaluation_result_list:
            return
        # Iterations are logged 1-based, matching the "[n]" prefix below.
        if (env.iteration + 1) % period != 0:
            return
        result = '\t'.join(_format_eval_result(x, show_stdv) for x in env.evaluation_result_list)
        _log_info(f'[{env.iteration + 1}]\t{result}')
    _callback.order = 10  # type: ignore
    return _callback
wxchan's avatar
wxchan committed
94
95


96
def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss'
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

        With ``lgb.cv``, evaluation entries are aggregated 5-tuples, and the
        recorded series are stored under ``'<metric>-mean'`` and ``'<metric>-stdv'`` keys.

    Returns
    -------
    callback : callable
        The callback that records the evaluation history into the passed dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')

    # String annotations below keep CallbackEnv lazily resolved.
    def _init(env: "CallbackEnv") -> None:
        """Wipe ``eval_result`` and create one empty series per (dataset, metric)."""
        eval_result.clear()
        for item in env.evaluation_result_list:
            if len(item) == 4:  # regular train: (data_name, eval_name, result, is_higher_better)
                data_name, eval_name = item[:2]
                eval_result.setdefault(data_name, collections.OrderedDict())
                eval_result[data_name].setdefault(eval_name, [])
            else:  # cv: ('cv_agg', '<data_name> <eval_name>', mean, is_higher_better, stdv)
                data_name, eval_name = item[1].split(" ")
                eval_result.setdefault(data_name, collections.OrderedDict())
                eval_result[data_name].setdefault(f'{eval_name}-mean', [])
                eval_result[data_name].setdefault(f'{eval_name}-stdv', [])

    def _callback(env: "CallbackEnv") -> None:
        """Append the current iteration's results to the history dictionary."""
        if env.iteration == env.begin_iteration:
            _init(env)
        for item in env.evaluation_result_list:
            if len(item) == 4:  # regular train
                data_name, eval_name, result = item[:3]
                eval_result[data_name][eval_name].append(result)
            else:  # cv: record aggregated mean and stdv separately
                data_name, eval_name = item[1].split(" ")
                eval_result[data_name][f'{eval_name}-mean'].append(item[2])
                eval_result[data_name][f'{eval_name}-stdv'].append(item[4])
    _callback.order = 20  # type: ignore
    return _callback
wxchan's avatar
wxchan committed
145
146


147
def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take in-effect on first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : callable
        The callback that resets the parameter after the first iteration.
    """
    def _callback(env: CallbackEnv) -> None:
        # Index of the current round relative to the start of this training run.
        round_idx = env.iteration - env.begin_iteration
        total_rounds = env.end_iteration - env.begin_iteration
        updates = {}
        for name, spec in kwargs.items():
            if isinstance(spec, list):
                if len(spec) != total_rounds:
                    raise ValueError(f"Length of list {name!r} has to equal to 'num_boost_round'.")
                candidate = spec[round_idx]
            else:
                candidate = spec(round_idx)
            # Only reset parameters whose value actually changed this round.
            if candidate != env.params.get(name, None):
                updates[name] = candidate
        if updates:
            env.model.reset_parameter(updates)
            env.params.update(updates)
    _callback.before_iteration = True  # type: ignore
    _callback.order = 10  # type: ignore
    return _callback
wxchan's avatar
wxchan committed
185
186


187
def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them. But the training data is ignored anyway.
    To check only the first metric set ``first_metric_only`` to True.
    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.

    Parameters
    ----------
    stopping_rounds : int
        The possible number of rounds without the trend occurrence.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

    Returns
    -------
    callback : callable
        The callback that activates early stopping.
    """
    # Per-entry bookkeeping, one slot for each element of env.evaluation_result_list;
    # (re)filled by _init on the first iteration of each training run.
    best_score = []             # best score seen so far for each entry
    best_iter = []              # iteration at which that best score occurred
    best_score_list: list = []  # snapshot of the whole evaluation_result_list at the best iteration
    cmp_op = []                 # per-entry comparison deciding whether a score "improved"
    enabled = True              # set False when boosting is 'dart' (early stopping unavailable)
    first_metric = ''           # name of the first metric, used with first_metric_only

    def _init(env: CallbackEnv) -> None:
        # Set up all bookkeeping from the first iteration's evaluation results.
        nonlocal best_score
        nonlocal best_iter
        nonlocal best_score_list
        nonlocal cmp_op
        nonlocal enabled
        nonlocal first_metric
        # Early stopping is disabled entirely when any boosting alias is set to 'dart'.
        enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                          in _ConfigAliases.get("boosting"))
        if not enabled:
            _log_warning('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if verbose:
            _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")

        # reset storages
        best_score = []
        best_iter = []
        best_score_list = []
        cmp_op = []
        first_metric = ''

        # evaluation_result_list holds n_datasets * n_metrics entries;
        # the min_delta value(s) are broadcast to one delta per entry below.
        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
        n_datasets = len(env.evaluation_result_list) // n_metrics
        if isinstance(min_delta, list):
            if not all(t >= 0 for t in min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(min_delta) == 0:
                if verbose:
                    _log_info('Disabling min_delta for early stopping.')
                deltas = [0.0] * n_datasets * n_metrics
            elif len(min_delta) == 1:
                if verbose:
                    _log_info(f'Using {min_delta[0]} as min_delta for all metrics.')
                deltas = min_delta * n_datasets * n_metrics
            else:
                if len(min_delta) != n_metrics:
                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
                if first_metric_only and verbose:
                    _log_info(f'Using only {min_delta[0]} as early stopping min_delta.')
                # NOTE(review): tiling by n_datasets assumes metrics appear in the
                # same order for every dataset in evaluation_result_list — TODO confirm.
                deltas = min_delta * n_datasets
        else:
            if min_delta < 0:
                raise ValueError('Early stopping min_delta must be non-negative.')
            if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose:
                _log_info(f'Using {min_delta} as min_delta for all metrics.')
            deltas = [min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            best_iter.append(0)
            best_score_list.append(None)
            if eval_ret[3]:  # greater is better
                best_score.append(float('-inf'))
                cmp_op.append(partial(_gt_delta, delta=delta))
            else:
                best_score.append(float('inf'))
                cmp_op.append(partial(_lt_delta, delta=delta))

    def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        # On the very last iteration, raise EarlyStopException anyway so the
        # caller records best_iteration even though the patience never ran out.
        nonlocal best_iter
        nonlocal best_score_list
        if env.iteration == env.end_iteration - 1:
            if verbose:
                best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
                if first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(best_iter[i], best_score_list[i])

    def _callback(env: CallbackEnv) -> None:
        # Update best-so-far state for every entry, then stop when the chosen
        # metric(s) have not improved for stopping_rounds iterations.
        nonlocal best_score
        nonlocal best_iter
        nonlocal best_score_list
        nonlocal cmp_op
        nonlocal enabled
        nonlocal first_metric
        if env.iteration == env.begin_iteration:
            _init(env)
        if not enabled:
            return
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if first_metric_only and first_metric != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                _final_iteration_check(env, eval_name_splitted, i)
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
                    if first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(best_iter[i], best_score_list[i])
            _final_iteration_check(env, eval_name_splitted, i)
    _callback.order = 30  # type: ignore
    return _callback