# coding: utf-8
"""Callbacks library."""
import collections
from operator import gt, lt
from typing import Any, Callable, Dict, List, Union

from .basic import _ConfigAliases, _log_info, _log_warning


class EarlyStopException(Exception):
    """Exception of early stopping."""

    def __init__(self, best_iteration: int, best_score: float) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
        best_score : float
            The score of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])


def _format_eval_result(value: list, show_stdv: bool = True) -> str:
    """Format metric string."""
    if len(value) == 4:
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    elif len(value) == 5:
        if show_stdv:
            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"
        else:
            return f"{value[0]}'s {value[1]}: {value[2]:g}"
    else:
        raise ValueError("Wrong metric value")


def print_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
    """Create a callback that logs the evaluation results.

    Deprecated, use ``log_evaluation()`` instead.
    """
    _log_warning("'print_evaluation()' callback is deprecated and will be removed in a future release of LightGBM. "
                 "Use 'log_evaluation()' callback instead.")
    return log_evaluation(period=period, show_stdv=show_stdv)


def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
    """Create a callback that logs the evaluation results.

    By default, standard output resource is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation data.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : callable
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    def _callback(env: CallbackEnv) -> None:
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            _log_info(f'[{env.iteration + 1}]\t{result}')
    _callback.order = 10  # type: ignore
    return _callback
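# Usage sketch for ``log_evaluation`` (illustrative only; ``params``, ``train_set`` and
# ``valid_set`` are placeholders for objects created elsewhere):
#
#     import lightgbm as lgb
#     booster = lgb.train(
#         params,
#         train_set,
#         valid_sets=[valid_set],
#         callbacks=[log_evaluation(period=10)],
#     )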


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss',
        this dictionary will have the following structure after the training process finishes:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : callable
        The callback that records the evaluation history into the passed dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')
    eval_result.clear()

    def _init(env: CallbackEnv) -> None:
        for data_name, eval_name, _, _ in env.evaluation_result_list:
            eval_result.setdefault(data_name, collections.OrderedDict())
            eval_result[data_name].setdefault(eval_name, [])

    def _callback(env: CallbackEnv) -> None:
        if not eval_result:
            _init(env)
        for data_name, eval_name, result, _ in env.evaluation_result_list:
            eval_result[data_name][eval_name].append(result)
    _callback.order = 20  # type: ignore
    return _callback
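# Usage sketch for ``record_evaluation`` (illustrative only; ``params``, ``train_set`` and
# ``valid_set`` are placeholders):
#
#     evals_result = {}  # filled in-place during training
#     booster = lgb.train(
#         params,
#         train_set,
#         valid_sets=[train_set, valid_set],
#         valid_names=['train', 'eval'],
#         callbacks=[record_evaluation(evals_result)],
#     )
#     # evals_result now maps dataset name -> metric name -> one value per boosting round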


def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round,
        or a callable that calculates the parameter from
        the current round number (e.g. to produce learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : callable
        The callback that resets the parameter after the first iteration.
    """
    def _callback(env: CallbackEnv) -> None:
        new_parameters = {}
        for key, value in kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.")
                new_param = value[env.iteration - env.begin_iteration]
            else:
                new_param = value(env.iteration - env.begin_iteration)
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    _callback.before_iteration = True  # type: ignore
    _callback.order = 10  # type: ignore
    return _callback
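# Usage sketch for ``reset_parameter`` (illustrative only): pass either a callable that
# receives the zero-based round index or a list with one value per boosting round.
#
#     callbacks=[reset_parameter(learning_rate=lambda cur_round: 0.1 * (0.99 ** cur_round))]
#     callbacks=[reset_parameter(learning_rate=[0.1] * 50 + [0.05] * 50)]  # num_boost_round == 100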


def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True) -> Callable:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score stops improving.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them. But the training data is ignored anyway.
    To check only the first metric, set ``first_metric_only`` to True.
    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without improvement after which training will be stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.

    Returns
    -------
    callback : callable
        The callback that activates early stopping.
    """
    best_score = []
    best_iter = []
    best_score_list: list = []
    cmp_op = []
    enabled = [True]
    first_metric = ['']

    def _init(env: CallbackEnv) -> None:
        enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                             in _ConfigAliases.get("boosting"))
        if not enabled[0]:
            _log_warning('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if verbose:
            _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret in env.evaluation_result_list:
            best_iter.append(0)
            best_score_list.append(None)
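            # eval_ret[3] is the is_higher_better flag of the metric:
            # maximize (gt) when True, minimize (lt) otherwise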
            if eval_ret[3]:
                best_score.append(float('-inf'))
                cmp_op.append(gt)
            else:
                best_score.append(float('inf'))
                cmp_op.append(lt)

    def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        if env.iteration == env.end_iteration - 1:
            if verbose:
                best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
                if first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(best_iter[i], best_score_list[i])

    def _callback(env: CallbackEnv) -> None:
        if not cmp_op:
            _init(env)
        if not enabled[0]:
            return
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                _final_iteration_check(env, eval_name_splitted, i)
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
                    if first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(best_iter[i], best_score_list[i])
            _final_iteration_check(env, eval_name_splitted, i)
    _callback.order = 30  # type: ignore
    return _callback
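# Usage sketch for ``early_stopping`` (illustrative only; ``params``, ``train_set`` and
# ``valid_set`` are placeholders):
#
#     booster = lgb.train(
#         params,
#         train_set,
#         valid_sets=[valid_set],
#         callbacks=[early_stopping(stopping_rounds=50)],
#     )
#     # booster.best_iteration holds the best iteration found by early stopping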