# coding: utf-8
"""Callbacks library."""
import collections
from functools import partial
from typing import Any, Callable, Dict, List, Union

from .basic import _ConfigAliases, _log_info, _log_warning


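# Comparison helpers for the early stopping callback: a new score counts as
# an improvement only if it beats the previous best by more than ``delta``
# (the user-supplied ``min_delta``).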
def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
    return curr_score > best_score + delta


def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
    return curr_score < best_score - delta


class EarlyStopException(Exception):
    """Exception of early stopping."""

    def __init__(self, best_iteration: int, best_score: float) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The iteration with the best score, at which training stopped.
        best_score : float
            The score of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])


def _format_eval_result(value: list, show_stdv: bool = True) -> str:
    """Format metric string."""
    if len(value) == 4:
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    elif len(value) == 5:
        if show_stdv:
            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"
        else:
            return f"{value[0]}'s {value[1]}: {value[2]:g}"
    else:
        raise ValueError("Wrong metric value")


def print_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
    """Create a callback that logs the evaluation results.

    Deprecated, use ``log_evaluation()`` instead.
    """
    _log_warning("'print_evaluation()' callback is deprecated and will be removed in a future release of LightGBM. "
                 "Use 'log_evaluation()' callback instead.")
    return log_evaluation(period=period, show_stdv=show_stdv)


def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
    """Create a callback that logs the evaluation results.

    By default, messages are written to standard output.
    Use the ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation data.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : callable
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    def _callback(env: CallbackEnv) -> None:
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            _log_info(f'[{env.iteration + 1}]\t{result}')
    _callback.order = 10  # type: ignore
    return _callback
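
# A minimal usage sketch, assuming hypothetical ``params``, ``dtrain`` and
# ``dvalid`` objects created elsewhere:
#
#     import lightgbm as lgb
#     bst = lgb.train(params, dtrain, valid_sets=[dvalid],
#                     callbacks=[lgb.log_evaluation(period=10)])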


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss',
        this dictionary will have the following structure after training finishes:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : callable
        The callback that records the evaluation history into the passed dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')
    eval_result.clear()

    def _init(env: CallbackEnv) -> None:
        for data_name, eval_name, _, _ in env.evaluation_result_list:
            eval_result.setdefault(data_name, collections.OrderedDict())
            eval_result[data_name].setdefault(eval_name, [])

    def _callback(env: CallbackEnv) -> None:
        if not eval_result:
            _init(env)
        for data_name, eval_name, result, _ in env.evaluation_result_list:
            eval_result[data_name][eval_name].append(result)
    _callback.order = 20  # type: ignore
    return _callback
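
# Illustrative sketch (assumes ``params``, ``dtrain``, ``dvalid`` exist):
# after training, ``history`` holds per-iteration metric values keyed first
# by dataset name, then by metric name.
#
#     import lightgbm as lgb
#     history = {}
#     bst = lgb.train(params, dtrain, valid_sets=[dvalid], valid_names=['eval'],
#                     callbacks=[lgb.record_evaluation(history)])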


def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round,
        or a callable that calculates the parameter as a function of the
        current round number (e.g. to produce learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : callable
        The callback that resets the parameter after the first iteration.
    """
    def _callback(env: CallbackEnv) -> None:
        new_parameters = {}
        for key, value in kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.")
                new_param = value[env.iteration - env.begin_iteration]
            else:
                new_param = value(env.iteration - env.begin_iteration)
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    _callback.before_iteration = True  # type: ignore
    _callback.order = 10  # type: ignore
    return _callback
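
# Illustrative sketch (assumes ``params`` and ``dtrain`` exist): decay the
# learning rate by 1% per round via a callable of the current round index.
#
#     import lightgbm as lgb
#     bst = lgb.train(params, dtrain,
#                     callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 * (0.99 ** i))])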


def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them, but the training data is ignored anyway.
    To check only the first metric, set ``first_metric_only`` to True.
    The index of the best-performing iteration will be saved in the ``best_iteration`` attribute of the model.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without any improvement after which training will be stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log messages with early stopping information.
        By default, messages are written to standard output.
        Use the ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

    Returns
    -------
    callback : callable
        The callback that activates early stopping.
    """
    best_score = []
    best_iter = []
    best_score_list: list = []
    cmp_op = []
    enabled = [True]
    first_metric = ['']

    def _init(env: CallbackEnv) -> None:
        enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                             in _ConfigAliases.get("boosting"))
        if not enabled[0]:
            _log_warning('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if verbose:
            _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")

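        # Normalize ``min_delta`` into one threshold per entry of
        # env.evaluation_result_list (one entry per dataset/metric pair).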
        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
        n_datasets = len(env.evaluation_result_list) // n_metrics
        if isinstance(min_delta, list):
            if not all(t >= 0 for t in min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(min_delta) == 0:
                if verbose:
                    _log_info('Disabling min_delta for early stopping.')
                deltas = [0.0] * n_datasets * n_metrics
            elif len(min_delta) == 1:
                if verbose:
                    _log_info(f'Using {min_delta[0]} as min_delta for all metrics.')
                deltas = min_delta * n_datasets * n_metrics
            else:
                if len(min_delta) != n_metrics:
                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
                if first_metric_only and verbose:
                    _log_info(f'Using only {min_delta[0]} as early stopping min_delta.')
                deltas = min_delta * n_datasets
        else:
            if min_delta < 0:
                raise ValueError('Early stopping min_delta must be non-negative.')
            if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose:
                _log_info(f'Using {min_delta} as min_delta for all metrics.')
            deltas = [min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            best_iter.append(0)
            best_score_list.append(None)
            if eval_ret[3]:  # greater is better
                best_score.append(float('-inf'))
                cmp_op.append(partial(_gt_delta, delta=delta))
            else:
                best_score.append(float('inf'))
                cmp_op.append(partial(_lt_delta, delta=delta))

    def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        if env.iteration == env.end_iteration - 1:
            if verbose:
                best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
                if first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(best_iter[i], best_score_list[i])

    def _callback(env: CallbackEnv) -> None:
        if not cmp_op:
            _init(env)
        if not enabled[0]:
            return
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                _final_iteration_check(env, eval_name_splitted, i)
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
                    if first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(best_iter[i], best_score_list[i])
            _final_iteration_check(env, eval_name_splitted, i)
    _callback.order = 30  # type: ignore
    return _callback
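
# Illustrative sketch (assumes ``params``, ``dtrain``, ``dvalid`` exist):
# stop if the validation score has not improved by at least 1e-4 for 50
# consecutive rounds; the best iteration is then available on the booster.
#
#     import lightgbm as lgb
#     bst = lgb.train(params, dtrain, valid_sets=[dvalid],
#                     callbacks=[lgb.early_stopping(stopping_rounds=50, min_delta=1e-4)])
#     print(bst.best_iteration)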