# coding: utf-8
"""Callbacks library."""
import collections
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union

from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning

__all__ = [
    'early_stopping',
    'log_evaluation',
    'record_evaluation',
    'reset_parameter',
]

_EvalResultDict = Dict[str, Dict[str, List[Any]]]
_EvalResultTuple = Union[
    List[_LGBM_BoosterEvalMethodResultType],
    List[Tuple[str, str, float, bool, float]]
]


class EarlyStopException(Exception):
    """Exception of early stopping."""

    def __init__(self, best_iteration: int, best_score: _EvalResultTuple) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration at which training stopped.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])


def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
    """Format metric string."""
    if len(value) == 4:
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    elif len(value) == 5:
        if show_stdv:
            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"
        else:
            return f"{value[0]}'s {value[1]}: {value[2]:g}"
    else:
        raise ValueError("Wrong metric value")


class _LogEvaluationCallback:
    """Internal log evaluation callable class."""

    def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
        self.order = 10
        self.before_iteration = False

        self.period = period
        self.show_stdv = show_stdv

    def __call__(self, env: CallbackEnv) -> None:
        if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0:
            result = '\t'.join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list])
            _log_info(f'[{env.iteration + 1}]\t{result}')


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
    """Create a callback that logs the evaluation results.

    By default, standard output is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation dataset.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : _LogEvaluationCallback
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    return _LogEvaluationCallback(period=period, show_stdv=show_stdv)
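
# Example usage (a minimal sketch; assumes ``params`` is a parameter dict and
# ``train_ds``/``valid_ds`` are existing ``lgb.Dataset`` objects):
#
#     bst = lgb.train(params, train_ds, valid_sets=[valid_ds],
#                     callbacks=[lgb.log_evaluation(period=10)])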


class _RecordEvaluationCallback:
    """Internal record evaluation callable class."""

    def __init__(self, eval_result: _EvalResultDict) -> None:
        self.order = 20
        self.before_iteration = False

        if not isinstance(eval_result, dict):
            raise TypeError('eval_result should be a dictionary')
        self.eval_result = eval_result

    def _init(self, env: CallbackEnv) -> None:
        self.eval_result.clear()
        for item in env.evaluation_result_list:
            if len(item) == 4:  # regular train
                data_name, eval_name = item[:2]
            else:  # cv
                data_name, eval_name = item[1].split()
            self.eval_result.setdefault(data_name, collections.OrderedDict())
            if len(item) == 4:
                self.eval_result[data_name].setdefault(eval_name, [])
            else:
                self.eval_result[data_name].setdefault(f'{eval_name}-mean', [])
                self.eval_result[data_name].setdefault(f'{eval_name}-stdv', [])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        for item in env.evaluation_result_list:
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                self.eval_result[data_name][eval_name].append(result)
            else:
                data_name, eval_name = item[1].split()
                res_mean = item[2]
                res_stdv = item[4]
                self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
                self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)


def record_evaluation(eval_result: _EvalResultDict) -> _RecordEvaluationCallback:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss',
        this dictionary will have the following structure after the training process finishes:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : _RecordEvaluationCallback
        The callback that records the evaluation history into the passed dictionary.
    """
    return _RecordEvaluationCallback(eval_result=eval_result)
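
# Example usage (a minimal sketch; assumes ``params``, ``train_ds`` and
# ``valid_ds`` already exist):
#
#     history = {}
#     bst = lgb.train(params, train_ds, valid_sets=[valid_ds], valid_names=['eval'],
#                     callbacks=[lgb.record_evaluation(history)])
#     # history['eval'] now maps each metric name to one value per iteration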


class _ResetParameterCallback:
    """Internal reset parameter callable class."""

    def __init__(self, **kwargs: Union[list, Callable]) -> None:
        self.order = 10
        self.before_iteration = True

        self.kwargs = kwargs

    def __call__(self, env: CallbackEnv) -> None:
        new_parameters = {}
        for key, value in self.kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.")
                new_param = value[env.iteration - env.begin_iteration]
            elif callable(value):
                new_param = value(env.iteration - env.begin_iteration)
            else:
                raise ValueError("Only list and callable values are supported "
                                 "as a mapping from boosting round index to new parameter value.")
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)


def reset_parameter(**kwargs: Union[list, Callable]) -> _ResetParameterCallback:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter from the current round number
        (e.g. produces a learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : _ResetParameterCallback
        The callback that resets the parameter after the first iteration.
    """
    return _ResetParameterCallback(**kwargs)
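
# Example usage (a minimal sketch; assumes ``params`` and ``train_ds`` already
# exist): decay the learning rate by 1% per round, starting from 0.1:
#
#     bst = lgb.train(params, train_ds, num_boost_round=100,
#                     callbacks=[lgb.reset_parameter(
#                         learning_rate=lambda current_round: 0.1 * (0.99 ** current_round))])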


class _EarlyStoppingCallback:
    """Internal early stopping callable class."""

    def __init__(
        self,
        stopping_rounds: int,
        first_metric_only: bool = False,
        verbose: bool = True,
        min_delta: Union[float, List[float]] = 0.0
    ) -> None:
        self.order = 30
        self.before_iteration = False

        self.stopping_rounds = stopping_rounds
        self.first_metric_only = first_metric_only
        self.verbose = verbose
        self.min_delta = min_delta

        self.enabled = True
        self._reset_storages()

    def _reset_storages(self) -> None:
        self.best_score: List[float] = []
        self.best_iter: List[int] = []
        self.best_score_list: List[Union[_EvalResultTuple, None]] = []
        self.cmp_op: List[Callable[[float, float], bool]] = []
        self.first_metric = ''

    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score > best_score + delta

    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score < best_score - delta

    def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
        return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name

    def _init(self, env: CallbackEnv) -> None:
        is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
        only_train_set = (
            len(env.evaluation_result_list) == 1
            and self._is_train_set(
                ds_name=env.evaluation_result_list[0][0],
                eval_name=env.evaluation_result_list[0][1].split(" ")[0],
                train_name=env.model._train_data_name)
        )
        self.enabled = not is_dart and not only_train_set
        if not self.enabled:
            if is_dart:
                _log_warning('Early stopping is not available in dart mode')
            elif only_train_set:
                _log_warning('Only training set found, disabling early stopping.')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if self.stopping_rounds <= 0:
            raise ValueError("stopping_rounds should be greater than zero.")

        if self.verbose:
            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

        self._reset_storages()

        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
        n_datasets = len(env.evaluation_result_list) // n_metrics
        if isinstance(self.min_delta, list):
            if not all(t >= 0 for t in self.min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(self.min_delta) == 0:
                if self.verbose:
                    _log_info('Disabling min_delta for early stopping.')
                deltas = [0.0] * n_datasets * n_metrics
            elif len(self.min_delta) == 1:
                if self.verbose:
                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
                deltas = self.min_delta * n_datasets * n_metrics
            else:
                if len(self.min_delta) != n_metrics:
                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
                if self.first_metric_only and self.verbose:
                    _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
                deltas = self.min_delta * n_datasets
        else:
            if self.min_delta < 0:
                raise ValueError('Early stopping min_delta must be non-negative.')
            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
                _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
            deltas = [self.min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            self.best_iter.append(0)
            self.best_score_list.append(None)
            if eval_ret[3]:  # greater is better
                self.best_score.append(float('-inf'))
                self.cmp_op.append(partial(self._gt_delta, delta=delta))
            else:
                self.best_score.append(float('inf'))
                self.cmp_op.append(partial(self._lt_delta, delta=delta))

    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        if env.iteration == env.end_iteration - 1:
            if self.verbose:
                best_score_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
                if self.first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        if not self.enabled:
            return
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if self.best_score_list[i] is None or self.cmp_op[i](score, self.best_score[i]):
                self.best_score[i] = score
                self.best_iter[i] = env.iteration
                self.best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                if self.verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                    if self.first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
            self._final_iteration_check(env, eval_name_splitted, i)


def early_stopping(
    stopping_rounds: int,
    first_metric_only: bool = False,
    verbose: bool = True,
    min_delta: Union[float, List[float]] = 0.0
) -> _EarlyStoppingCallback:
    """Create a callback that activates early stopping.

    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation dataset and one metric.
    If there's more than one, all of them will be checked, but the training data is ignored anyway.
    To check only the first metric, set ``first_metric_only`` to True.
    The index of the iteration with the best performance will be saved in the ``best_iteration`` attribute of the model.

    Parameters
    ----------
    stopping_rounds : int
        The number of consecutive rounds without improvement after which training is stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log messages with early stopping information.
        By default, standard output is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

    Returns
    -------
    callback : _EarlyStoppingCallback
        The callback that activates early stopping.
    """
    return _EarlyStoppingCallback(
        stopping_rounds=stopping_rounds,
        first_metric_only=first_metric_only,
        verbose=verbose,
        min_delta=min_delta
    )
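
# Example usage (a minimal sketch; assumes ``params``, ``train_ds`` and
# ``valid_ds`` already exist):
#
#     bst = lgb.train(params, train_ds, num_boost_round=1000, valid_sets=[valid_ds],
#                     callbacks=[lgb.early_stopping(stopping_rounds=50, min_delta=1e-4)])
#     print(bst.best_iteration)  # the iteration with the best validation score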