callback.py 15.5 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Callbacks library."""
wxchan's avatar
wxchan committed
3
import collections
4
from functools import partial
5
from typing import Any, Callable, Dict, List, Tuple, Union
wxchan's avatar
wxchan committed
6

7
from .basic import _ConfigAliases, _log_info, _log_warning
wxchan's avatar
wxchan committed
8

9
10
11
12
13
# An evaluation result is a (eval_name, metric_name, value, is_higher_better)
# tuple; cross-validation results carry an extra stdv entry at the end.
_EvalResultTuple = Union[
    List[Tuple[str, str, float, bool]],
    List[Tuple[str, str, float, bool, float]]
]


class EarlyStopException(Exception):
    """Exception of early stopping.

    Raised internally to abort training, carrying the best iteration found
    and the scores recorded at that iteration.
    """

    def __init__(self, best_iteration: int, best_score: _EvalResultTuple) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        super().__init__()
        self.best_score = best_score
        self.best_iteration = best_iteration
wxchan's avatar
wxchan committed
31

wxchan's avatar
wxchan committed
32

wxchan's avatar
wxchan committed
33
34
# Snapshot of training state handed to every callback at each iteration.
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    [
        "model",
        "params",
        "iteration",
        "begin_iteration",
        "end_iteration",
        "evaluation_result_list",
    ])

wxchan's avatar
wxchan committed
43

44
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
    """Format metric string."""
    n_fields = len(value)
    # 4 fields: (data_name, metric_name, score, is_higher_better).
    if n_fields == 4:
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    # 5 fields: CV result with a stdv appended in the last slot.
    if n_fields == 5:
        if show_stdv:
            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    raise ValueError("Wrong metric value")
wxchan's avatar
wxchan committed
55
56


57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class _LogEvaluationCallback:
    """Internal log evaluation callable class."""

    def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
        # Runs after each iteration, after lower-order callbacks.
        self.order = 10
        self.before_iteration = False

        self.period = period
        self.show_stdv = show_stdv

    def __call__(self, env: CallbackEnv) -> None:
        # Log only when enabled (period > 0), there is something to report,
        # and the 1-based iteration number hits the period boundary.
        if self.period <= 0 or not env.evaluation_result_list:
            return
        iteration_no = env.iteration + 1
        if iteration_no % self.period != 0:
            return
        result = '\t'.join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list])
        _log_info(f'[{iteration_no}]\t{result}')


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
    """Create a callback that logs the evaluation results.

    By default, standard output resource is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation data.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : _LogEvaluationCallback
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    callback = _LogEvaluationCallback(period=period, show_stdv=show_stdv)
    return callback
wxchan's avatar
wxchan committed
97
98


99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
class _RecordEvaluationCallback:
    """Internal record evaluation callable class."""

    def __init__(self, eval_result: Dict[str, Dict[str, List[Any]]]) -> None:
        # Runs after logging callbacks (order 10) at each iteration.
        self.order = 20
        self.before_iteration = False

        if not isinstance(eval_result, dict):
            raise TypeError('eval_result should be a dictionary')
        self.eval_result = eval_result

    def _init(self, env: CallbackEnv) -> None:
        # Rebuild the storage from scratch so no stale results linger.
        self.eval_result.clear()
        for item in env.evaluation_result_list:
            is_cv = len(item) != 4
            if is_cv:  # cv: "<dataset> <metric>" packed into one field
                data_name, eval_name = item[1].split()
            else:  # regular train
                data_name, eval_name = item[:2]
            metrics = self.eval_result.setdefault(data_name, collections.OrderedDict())
            if is_cv:
                metrics.setdefault(f'{eval_name}-mean', [])
                metrics.setdefault(f'{eval_name}-stdv', [])
            else:
                metrics.setdefault(eval_name, [])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        for item in env.evaluation_result_list:
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                self.eval_result[data_name][eval_name].append(result)
            else:
                data_name, eval_name = item[1].split()
                self.eval_result[data_name][f'{eval_name}-mean'].append(item[2])
                self.eval_result[data_name][f'{eval_name}-stdv'].append(item[4])


138
def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss'
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : _RecordEvaluationCallback
        The callback that records the evaluation history into the passed dictionary.
    """
    callback = _RecordEvaluationCallback(eval_result=eval_result)
    return callback
wxchan's avatar
wxchan committed
172
173


174
def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take in-effect on first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : callable
        The callback that resets the parameter after the first iteration.
    """
    def _callback(env: CallbackEnv) -> None:
        # Round index relative to where this training run started.
        round_idx = env.iteration - env.begin_iteration
        total_rounds = env.end_iteration - env.begin_iteration
        updated = {}
        for key, value in kwargs.items():
            if isinstance(value, list):
                if len(value) != total_rounds:
                    raise ValueError(f"Length of list {key!r} has to equal to 'num_boost_round'.")
                new_param = value[round_idx]
            else:
                new_param = value(round_idx)
            # Only push parameters that actually changed to avoid needless resets.
            if new_param != env.params.get(key, None):
                updated[key] = new_param
        if updated:
            env.model.reset_parameter(updated)
            env.params.update(updated)

    # Runs before each iteration so the new parameters affect that round.
    _callback.before_iteration = True  # type: ignore
    _callback.order = 10  # type: ignore
    return _callback
wxchan's avatar
wxchan committed
212
213


214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
class _EarlyStoppingCallback:
    """Internal early stopping callable class."""

    def __init__(
        self,
        stopping_rounds: int,
        first_metric_only: bool = False,
        verbose: bool = True,
        min_delta: Union[float, List[float]] = 0.0
    ) -> None:
        # Runs last (after logging and recording callbacks) at each iteration.
        self.order = 30
        self.before_iteration = False

        self.stopping_rounds = stopping_rounds
        self.first_metric_only = first_metric_only
        self.verbose = verbose
        self.min_delta = min_delta

        self.enabled = True
        self._reset_storages()

    def _reset_storages(self) -> None:
        # Per-evaluation-entry state, filled in by _init().
        self.best_score = []
        self.best_iter = []
        self.best_score_list = []
        self.cmp_op = []
        self.first_metric = ''

    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        # Improvement test for higher-is-better metrics.
        return curr_score > best_score + delta

    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        # Improvement test for lower-is-better metrics.
        return curr_score < best_score - delta

    def _resolve_min_deltas(self, n_datasets: int, n_metrics: int) -> List[float]:
        # Expand ``min_delta`` into one threshold per evaluation entry,
        # validating user input along the way.
        if isinstance(self.min_delta, list):
            if not all(t >= 0 for t in self.min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(self.min_delta) == 0:
                if self.verbose:
                    _log_info('Disabling min_delta for early stopping.')
                return [0.0] * n_datasets * n_metrics
            if len(self.min_delta) == 1:
                if self.verbose:
                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
                return self.min_delta * n_datasets * n_metrics
            if len(self.min_delta) != n_metrics:
                raise ValueError('Must provide a single value for min_delta or as many as metrics.')
            if self.first_metric_only and self.verbose:
                _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
            return self.min_delta * n_datasets
        if self.min_delta < 0:
            raise ValueError('Early stopping min_delta must be non-negative.')
        if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
            _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
        return [self.min_delta] * n_datasets * n_metrics

    def _init(self, env: CallbackEnv) -> None:
        # Early stopping is disabled for dart boosting (any boosting alias
        # set to 'dart' in the params turns it off).
        self.enabled = not any(env.params.get(boost_alias, "") == 'dart'
                               for boost_alias in _ConfigAliases.get("boosting"))
        if not self.enabled:
            _log_warning('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')
        if self.stopping_rounds <= 0:
            raise ValueError("stopping_rounds should be greater than zero.")
        if self.verbose:
            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

        self._reset_storages()

        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
        n_datasets = len(env.evaluation_result_list) // n_metrics
        deltas = self._resolve_min_deltas(n_datasets, n_metrics)

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            self.best_iter.append(0)
            self.best_score_list.append(None)
            if eval_ret[3]:  # greater is better
                self.best_score.append(float('-inf'))
                self.cmp_op.append(partial(self._gt_delta, delta=delta))
            else:
                self.best_score.append(float('inf'))
                self.cmp_op.append(partial(self._lt_delta, delta=delta))

    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        # On the last boosting round, stop with the best result seen so far.
        if env.iteration == env.end_iteration - 1:
            if self.verbose:
                best_score_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
                if self.first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        if not self.enabled:
            return
        for i, item in enumerate(env.evaluation_result_list):
            score = item[2]
            # First iteration always counts as an improvement.
            if self.best_score_list[i] is None or self.cmp_op[i](score, self.best_score[i]):
                self.best_score[i] = score
                self.best_iter[i] = env.iteration
                self.best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = item[1].split(" ")
            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            is_train_set = (
                item[0] == "cv_agg" and eval_name_splitted[0] == "train"
                or item[0] == env.model._train_data_name
            )
            if is_train_set:
                # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
                self._final_iteration_check(env, eval_name_splitted, i)
                continue
            if env.iteration - self.best_iter[i] >= self.stopping_rounds:
                if self.verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                    if self.first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
            self._final_iteration_check(env, eval_name_splitted, i)


def early_stopping(
    stopping_rounds: int,
    first_metric_only: bool = False,
    verbose: bool = True,
    min_delta: Union[float, List[float]] = 0.0
) -> _EarlyStoppingCallback:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them. But the training data is ignored anyway.
    To check only the first metric set ``first_metric_only`` to True.
    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.

    Parameters
    ----------
    stopping_rounds : int
        The possible number of rounds without the trend occurrence.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

    Returns
    -------
    callback : _EarlyStoppingCallback
        The callback that activates early stopping.
    """
    return _EarlyStoppingCallback(
        stopping_rounds=stopping_rounds,
        first_metric_only=first_metric_only,
        verbose=verbose,
        min_delta=min_delta,
    )