# coding: utf-8
"""Callbacks library."""
import collections
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union

from .basic import _ConfigAliases, _log_info, _log_warning

_EvalResultTuple = Union[
    List[Tuple[str, str, float, bool]],
    List[Tuple[str, str, float, bool, float]]
]


def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
    return curr_score > best_score + delta


def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
    return curr_score < best_score - delta


class EarlyStopException(Exception):
    """Exception of early stopping."""

    def __init__(self, best_iteration: int, best_score: _EvalResultTuple) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration at which training stopped.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])

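# A minimal custom-callback sketch (illustrative only; ``print_iteration`` is a
# placeholder name, not part of this module). Any callable accepting a
# ``CallbackEnv`` can be used as a callback, and may optionally set the
# ``order``/``before_iteration`` attributes that the factories below also set:
#
#     def print_iteration(env: CallbackEnv) -> None:
#         _log_info(f"finished iteration {env.iteration + 1} of {env.end_iteration}")
#
#     print_iteration.order = 10  # relative position among callbacks run in the same phase
#     print_iteration.before_iteration = False  # run after each boosting update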

def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
    """Format metric string."""
    if len(value) == 4:
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    elif len(value) == 5:
        if show_stdv:
            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"
        else:
            return f"{value[0]}'s {value[1]}: {value[2]:g}"
    else:
        raise ValueError("Wrong metric value")


def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
    """Create a callback that logs the evaluation results.

    By default, standard output resource is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation dataset.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : callable
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    def _callback(env: CallbackEnv) -> None:
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            _log_info(f'[{env.iteration + 1}]\t{result}')
    _callback.order = 10  # type: ignore
    return _callback
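
# Usage sketch for ``log_evaluation`` (illustrative; ``params``, ``train_set`` and
# ``valid_set`` are placeholder objects assumed to be prepared elsewhere):
#
#     import lightgbm as lgb
#     booster = lgb.train(
#         params,
#         train_set,
#         valid_sets=[valid_set],
#         valid_names=['valid'],
#         callbacks=[lgb.log_evaluation(period=10)],
#     )  # logs the 'valid' metrics every 10 boosting iterations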


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss',
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : callable
        The callback that records the evaluation history into the passed dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')
    eval_result.clear()

    def _init(env: CallbackEnv) -> None:
        for data_name, eval_name, _, _ in env.evaluation_result_list:
            eval_result.setdefault(data_name, collections.OrderedDict())
            eval_result[data_name].setdefault(eval_name, [])

    def _callback(env: CallbackEnv) -> None:
        if not eval_result:
            _init(env)
        for data_name, eval_name, result, _ in env.evaluation_result_list:
            eval_result[data_name][eval_name].append(result)
    _callback.order = 20  # type: ignore
    return _callback
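
# Usage sketch for ``record_evaluation`` (illustrative; same placeholder setup as above):
#
#     import lightgbm as lgb
#     history = {}  # must be empty; filled in-place during training
#     lgb.train(
#         params,
#         train_set,
#         valid_sets=[valid_set],
#         valid_names=['valid'],
#         callbacks=[lgb.record_evaluation(history)],
#     )
#     # history -> {'valid': {'<metric name>': [score_iter_1, score_iter_2, ...]}}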


def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter as a function of the current round number
        (e.g. to produce learning-rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : callable
        The callback that resets the parameter after the first iteration.
    """
    def _callback(env: CallbackEnv) -> None:
        new_parameters = {}
        for key, value in kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {key!r} has to equal to 'num_boost_round'.")
                new_param = value[env.iteration - env.begin_iteration]
            else:
                new_param = value(env.iteration - env.begin_iteration)
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    _callback.before_iteration = True  # type: ignore
    _callback.order = 10  # type: ignore
    return _callback
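
# Usage sketch for ``reset_parameter`` (illustrative): decay the learning rate each
# round by passing a callable; a list of per-round values would work the same way.
#
#     import lightgbm as lgb
#     lr_decay = lgb.reset_parameter(learning_rate=lambda round_num: 0.1 * (0.99 ** round_num))
#     booster = lgb.train(params, train_set, callbacks=[lr_decay])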


def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation dataset and one metric.
    If there is more than one, all of them will be checked, but the training data is ignored anyway.
    To check only the first metric, set ``first_metric_only`` to True.
    The index of the iteration with the best performance will be saved in the ``best_iteration`` attribute of the model.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without improvement after which training will be stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

    Returns
    -------
    callback : callable
        The callback that activates early stopping.
    """
    best_score = []
    best_iter = []
    best_score_list: list = []
    cmp_op = []
    enabled = [True]
    first_metric = ['']

    def _init(env: CallbackEnv) -> None:
        enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                             in _ConfigAliases.get("boosting"))
        if not enabled[0]:
            _log_warning('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if verbose:
            _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")

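        # Expand ``min_delta`` into one threshold per entry of
        # ``env.evaluation_result_list`` (i.e. one per dataset/metric pair),
        # validating that every provided value is non-negative.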
        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
        n_datasets = len(env.evaluation_result_list) // n_metrics
        if isinstance(min_delta, list):
            if not all(t >= 0 for t in min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(min_delta) == 0:
                if verbose:
                    _log_info('Disabling min_delta for early stopping.')
                deltas = [0.0] * n_datasets * n_metrics
            elif len(min_delta) == 1:
                if verbose:
                    _log_info(f'Using {min_delta[0]} as min_delta for all metrics.')
                deltas = min_delta * n_datasets * n_metrics
            else:
                if len(min_delta) != n_metrics:
                    raise ValueError('Must provide a single value for min_delta, or one value per metric.')
                if first_metric_only and verbose:
                    _log_info(f'Using only {min_delta[0]} as early stopping min_delta.')
                deltas = min_delta * n_datasets
        else:
            if min_delta < 0:
                raise ValueError('Early stopping min_delta must be non-negative.')
            if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose:
                _log_info(f'Using {min_delta} as min_delta for all metrics.')
            deltas = [min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            best_iter.append(0)
            best_score_list.append(None)
            if eval_ret[3]:  # greater is better
                best_score.append(float('-inf'))
                cmp_op.append(partial(_gt_delta, delta=delta))
            else:
                best_score.append(float('inf'))
                cmp_op.append(partial(_lt_delta, delta=delta))

    def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        if env.iteration == env.end_iteration - 1:
            if verbose:
                best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
                if first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(best_iter[i], best_score_list[i])

    def _callback(env: CallbackEnv) -> None:
        if not cmp_op:
            _init(env)
        if not enabled[0]:
            return
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                _final_iteration_check(env, eval_name_splitted, i)
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
                    if first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(best_iter[i], best_score_list[i])
            _final_iteration_check(env, eval_name_splitted, i)
    _callback.order = 30  # type: ignore
    return _callback
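

# Usage sketch for ``early_stopping`` (illustrative; placeholder ``params``/datasets
# as in the sketches above):
#
#     import lightgbm as lgb
#     booster = lgb.train(
#         params,
#         train_set,
#         num_boost_round=1000,
#         valid_sets=[valid_set],
#         callbacks=[lgb.early_stopping(stopping_rounds=50, min_delta=1e-4)],
#     )
#     print(booster.best_iteration, booster.best_score)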