"R-package/tests/vscode:/vscode.git/clone" did not exist on "3b6ebd794b82e02f8d5e1d0b915533bb4c36dbfc"
callback.py 14.8 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
"""Callbacks library."""
import collections
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union

from .basic import _ConfigAliases, _log_info, _log_warning

_EvalResultTuple = Union[
    List[Tuple[str, str, float, bool]],
    List[Tuple[str, str, float, bool, float]]
]


class EarlyStopException(Exception):
    """Exception of early stopping."""

    def __init__(self, best_iteration: int, best_score: _EvalResultTuple) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration at which training stopped.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])


def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
    """Format metric string."""
    if len(value) == 4:
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    elif len(value) == 5:
        if show_stdv:
            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"
        else:
            return f"{value[0]}'s {value[1]}: {value[2]:g}"
    else:
        raise ValueError("Wrong metric value")
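
# Illustrative examples of the formatting above (hypothetical values):
#     _format_eval_result(("valid_0", "auc", 0.81, True))       -> "valid_0's auc: 0.81"
#     _format_eval_result(("cv_agg", "auc", 0.81, True, 0.02))  -> "cv_agg's auc: 0.81 + 0.02"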


def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
    """Create a callback that logs the evaluation results.

    By default, standard output resource is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation data.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : callable
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
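
    Example
    -------
    A minimal usage sketch; ``params``, ``dtrain`` and ``dvalid`` are assumed
    to already exist (illustrative names, not part of this module):

    .. code-block:: python

        import lightgbm as lgb

        bst = lgb.train(params, dtrain, valid_sets=[dvalid],
                        callbacks=[log_evaluation(period=10)])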
    """
    def _callback(env: CallbackEnv) -> None:
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            _log_info(f'[{env.iteration + 1}]\t{result}')
    _callback.order = 10  # type: ignore
    return _callback


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss',
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : callable
        The callback that records the evaluation history into the passed dictionary.
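
    Example
    -------
    A usage sketch with illustrative names (``params``, ``dtrain`` and ``dvalid``
    are assumed to already exist; the metric key depends on the configured metric):

    .. code-block:: python

        import lightgbm as lgb

        history = {}
        bst = lgb.train(params, dtrain, valid_sets=[dvalid], valid_names=['eval'],
                        callbacks=[record_evaluation(history)])
        # e.g. history['eval']['binary_logloss'] then holds one entry per iteration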
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')

    def _init(env: CallbackEnv) -> None:
        eval_result.clear()
        for item in env.evaluation_result_list:
            if len(item) == 4:  # regular train
                data_name, eval_name = item[:2]
            else:  # cv
                data_name, eval_name = item[1].split()
            eval_result.setdefault(data_name, collections.OrderedDict())
            if len(item) == 4:
                eval_result[data_name].setdefault(eval_name, [])
            else:
                eval_result[data_name].setdefault(f'{eval_name}-mean', [])
                eval_result[data_name].setdefault(f'{eval_name}-stdv', [])

    def _callback(env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            _init(env)
        for item in env.evaluation_result_list:
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                eval_result[data_name][eval_name].append(result)
            else:
                data_name, eval_name = item[1].split()
                res_mean, res_stdv = item[2], item[4]
                eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
                eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)
    _callback.order = 20  # type: ignore
    return _callback


def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter in terms of
        the current number of rounds (e.g. learning rate decay).
        If a list ``lst`` is given, parameter = ``lst[current_round]``.
        If a callable ``func`` is given, parameter = ``func(current_round)``.

    Returns
    -------
    callback : callable
        The callback that resets the parameter after the first iteration.
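
    Example
    -------
    A sketch of learning-rate decay via a callable (illustrative names and
    values; ``params`` and ``dtrain`` are assumed to already exist):

    .. code-block:: python

        import lightgbm as lgb

        bst = lgb.train(params, dtrain, num_boost_round=100,
                        callbacks=[reset_parameter(learning_rate=lambda i: 0.1 * (0.99 ** i))])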
    """
    def _callback(env: CallbackEnv) -> None:
        new_parameters = {}
        for key, value in kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {key!r} has to equal 'num_boost_round'.")
                new_param = value[env.iteration - env.begin_iteration]
            else:
                new_param = value(env.iteration - env.begin_iteration)
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    _callback.before_iteration = True  # type: ignore
    _callback.order = 10  # type: ignore
    return _callback


class _EarlyStoppingCallback:
    """Internal early stopping callable class."""

    def __init__(
        self,
        stopping_rounds: int,
        first_metric_only: bool = False,
        verbose: bool = True,
        min_delta: Union[float, List[float]] = 0.0
    ) -> None:
        self.order = 30
        self.before_iteration = False

        self.stopping_rounds = stopping_rounds
        self.first_metric_only = first_metric_only
        self.verbose = verbose
        self.min_delta = min_delta

        self.enabled = True
        self._reset_storages()

    def _reset_storages(self) -> None:
        self.best_score = []
        self.best_iter = []
        self.best_score_list = []
        self.cmp_op = []
        self.first_metric = ''

    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score > best_score + delta

    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score < best_score - delta

    def _init(self, env: CallbackEnv) -> None:
        self.enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                               in _ConfigAliases.get("boosting"))
        if not self.enabled:
            _log_warning('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if self.stopping_rounds <= 0:
            raise ValueError("stopping_rounds should be greater than zero.")

        if self.verbose:
            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

        self._reset_storages()

        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
        n_datasets = len(env.evaluation_result_list) // n_metrics
        if isinstance(self.min_delta, list):
            if not all(t >= 0 for t in self.min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(self.min_delta) == 0:
                if self.verbose:
                    _log_info('Disabling min_delta for early stopping.')
                deltas = [0.0] * n_datasets * n_metrics
            elif len(self.min_delta) == 1:
                if self.verbose:
                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
                deltas = self.min_delta * n_datasets * n_metrics
            else:
                if len(self.min_delta) != n_metrics:
                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
                if self.first_metric_only and self.verbose:
                    _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
                deltas = self.min_delta * n_datasets
        else:
            if self.min_delta < 0:
                raise ValueError('Early stopping min_delta must be non-negative.')
            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
                _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
            deltas = [self.min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            self.best_iter.append(0)
            self.best_score_list.append(None)
            if eval_ret[3]:  # greater is better
                self.best_score.append(float('-inf'))
                self.cmp_op.append(partial(self._gt_delta, delta=delta))
            else:
                self.best_score.append(float('inf'))
                self.cmp_op.append(partial(self._lt_delta, delta=delta))

    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        if env.iteration == env.end_iteration - 1:
            if self.verbose:
                best_score_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
                if self.first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        if not self.enabled:
            return
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if self.best_score_list[i] is None or self.cmp_op[i](score, self.best_score[i]):
                self.best_score[i] = score
                self.best_iter[i] = env.iteration
                self.best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                    or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                self._final_iteration_check(env, eval_name_splitted, i)
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                if self.verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                    if self.first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
            self._final_iteration_check(env, eval_name_splitted, i)


def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> _EarlyStoppingCallback:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, all of them will be checked, but the training data is ignored.
    To check only the first metric set ``first_metric_only`` to True.
    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.

    Parameters
    ----------
    stopping_rounds : int
        The number of consecutive rounds without improvement after which training is stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

    Returns
    -------
    callback : _EarlyStoppingCallback
        The callback that activates early stopping.
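
    Example
    -------
    A usage sketch with illustrative names (``params``, ``dtrain`` and ``dvalid``
    are assumed to already exist):

    .. code-block:: python

        import lightgbm as lgb

        bst = lgb.train(params, dtrain, num_boost_round=1000, valid_sets=[dvalid],
                        callbacks=[early_stopping(stopping_rounds=50, min_delta=1e-4)])
        # the best iteration is saved in bst.best_iteration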
    """
    return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta)