# coding: utf-8
"""Callbacks library."""
import collections
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union

from .basic import _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning

__all__ = [
    'early_stopping',
    'log_evaluation',
    'record_evaluation',
    'reset_parameter',
]

_EvalResultDict = Dict[str, Dict[str, List[Any]]]
_EvalResultTuple = Union[
    _LGBM_BoosterEvalMethodResultType,
    Tuple[str, str, float, bool, float]
]
_ListOfEvalResultTuples = Union[
    List[_LGBM_BoosterEvalMethodResultType],
    List[Tuple[str, str, float, bool, float]]
]


class EarlyStopException(Exception):
    """Exception of early stopping."""

    def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration at which training stopped.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])


def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
    """Format metric string."""
    if len(value) == 4:
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    elif len(value) == 5:
        if show_stdv:
            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"  # type: ignore[misc]
        else:
            return f"{value[0]}'s {value[1]}: {value[2]:g}"
    else:
        raise ValueError("Wrong metric value")


class _LogEvaluationCallback:
    """Internal log evaluation callable class."""

    def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
        self.order = 10
        self.before_iteration = False

        self.period = period
        self.show_stdv = show_stdv

    def __call__(self, env: CallbackEnv) -> None:
        if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0:
            result = '\t'.join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list])
            _log_info(f'[{env.iteration + 1}]\t{result}')


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
    """Create a callback that logs the evaluation results.

    By default, standard output is used.
    Use the ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation dataset.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using the ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log the standard deviation (if provided).

    Returns
    -------
    callback : _LogEvaluationCallback
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
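
    Examples
    --------
    A minimal usage sketch (illustrative only; the random data and parameter values
    below are assumptions, not part of this module):

    .. code-block:: python

        import numpy as np
        import lightgbm as lgb

        X, y = np.random.rand(100, 2), np.random.rand(100)
        train_set = lgb.Dataset(X[:80], label=y[:80])
        valid_set = lgb.Dataset(X[80:], label=y[80:], reference=train_set)
        lgb.train(
            {"objective": "regression"},
            train_set,
            num_boost_round=20,
            valid_sets=[valid_set],
            callbacks=[lgb.log_evaluation(period=5)],  # log every 5th iteration
        )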
    """
    return _LogEvaluationCallback(period=period, show_stdv=show_stdv)


class _RecordEvaluationCallback:
    """Internal record evaluation callable class."""

    def __init__(self, eval_result: _EvalResultDict) -> None:
        self.order = 20
        self.before_iteration = False

        if not isinstance(eval_result, dict):
            raise TypeError('eval_result should be a dictionary')
        self.eval_result = eval_result

    def _init(self, env: CallbackEnv) -> None:
        self.eval_result.clear()
        for item in env.evaluation_result_list:
            if len(item) == 4:  # regular train
                data_name, eval_name = item[:2]
            else:  # cv
                data_name, eval_name = item[1].split()
            self.eval_result.setdefault(data_name, collections.OrderedDict())
            if len(item) == 4:
                self.eval_result[data_name].setdefault(eval_name, [])
            else:
                self.eval_result[data_name].setdefault(f'{eval_name}-mean', [])
                self.eval_result[data_name].setdefault(f'{eval_name}-stdv', [])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        for item in env.evaluation_result_list:
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                self.eval_result[data_name][eval_name].append(result)
            else:
                data_name, eval_name = item[1].split()
                res_mean = item[2]
                res_stdv = item[4]
                self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
                self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> _RecordEvaluationCallback:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss',
        this dictionary will have the following structure after the training finishes:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : _RecordEvaluationCallback
        The callback that records the evaluation history into the passed dictionary.
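
    Examples
    --------
    A minimal usage sketch (illustrative only; the random data and names below are
    assumptions, not part of this module):

    .. code-block:: python

        import numpy as np
        import lightgbm as lgb

        X, y = np.random.rand(100, 2), np.random.rand(100)
        train_set = lgb.Dataset(X[:80], label=y[:80])
        valid_set = lgb.Dataset(X[80:], label=y[80:], reference=train_set)
        history = {}
        lgb.train(
            {"objective": "regression"},
            train_set,
            num_boost_round=20,
            valid_sets=[train_set, valid_set],
            valid_names=["train", "eval"],
            callbacks=[lgb.record_evaluation(history)],
        )
        # history now has the nested structure shown above, e.g. history["eval"]["l2"]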
    """
    return _RecordEvaluationCallback(eval_result=eval_result)


class _ResetParameterCallback:
    """Internal reset parameter callable class."""

    def __init__(self, **kwargs: Union[list, Callable]) -> None:
        self.order = 10
        self.before_iteration = True

        self.kwargs = kwargs

    def __call__(self, env: CallbackEnv) -> None:
        new_parameters = {}
        for key, value in self.kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.")
                new_param = value[env.iteration - env.begin_iteration]
            elif callable(value):
                new_param = value(env.iteration - env.begin_iteration)
            else:
                raise ValueError("Only list and callable values are supported "
                                 "as a mapping from boosting round index to new parameter value.")
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)


def reset_parameter(**kwargs: Union[list, Callable]) -> _ResetParameterCallback:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameter values for each boosting round,
        or a callable that computes the parameter value in terms of
        the current round number (e.g. to implement learning-rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : _ResetParameterCallback
        The callback that resets the parameter after the first iteration.
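
    Examples
    --------
    A minimal usage sketch (illustrative only; the random data and the decay schedule
    below are assumptions, not part of this module):

    .. code-block:: python

        import numpy as np
        import lightgbm as lgb

        X, y = np.random.rand(100, 2), np.random.rand(100)
        train_set = lgb.Dataset(X, label=y)
        lgb.train(
            {"objective": "regression", "learning_rate": 0.1},
            train_set,
            num_boost_round=20,
            callbacks=[
                # decay the learning rate by 1% per boosting round
                lgb.reset_parameter(learning_rate=lambda round_num: 0.1 * (0.99 ** round_num)),
            ],
        )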
    """
    return _ResetParameterCallback(**kwargs)


class _EarlyStoppingCallback:
    """Internal early stopping callable class."""

    def __init__(
        self,
        stopping_rounds: int,
        first_metric_only: bool = False,
        verbose: bool = True,
        min_delta: Union[float, List[float]] = 0.0
    ) -> None:
        self.order = 30
        self.before_iteration = False

        self.stopping_rounds = stopping_rounds
        self.first_metric_only = first_metric_only
        self.verbose = verbose
        self.min_delta = min_delta

        self.enabled = True
        self._reset_storages()

    def _reset_storages(self) -> None:
        self.best_score: List[float] = []
        self.best_iter: List[int] = []
        self.best_score_list: List[_ListOfEvalResultTuples] = []
        self.cmp_op: List[Callable[[float, float], bool]] = []
        self.first_metric = ''

    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score > best_score + delta

    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score < best_score - delta

    def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
        return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name

    def _init(self, env: CallbackEnv) -> None:
        is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
        only_train_set = (
            len(env.evaluation_result_list) == 1
            and self._is_train_set(
                ds_name=env.evaluation_result_list[0][0],
                eval_name=env.evaluation_result_list[0][1].split(" ")[0],
                train_name=env.model._train_data_name)
        )
        self.enabled = not is_dart and not only_train_set
        if not self.enabled:
            if is_dart:
                _log_warning('Early stopping is not available in dart mode')
            elif only_train_set:
                _log_warning('Only training set found, disabling early stopping.')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if self.stopping_rounds <= 0:
            raise ValueError("stopping_rounds should be greater than zero.")

        if self.verbose:
            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

        self._reset_storages()

        n_metrics = len({m[1] for m in env.evaluation_result_list})
        n_datasets = len(env.evaluation_result_list) // n_metrics
        if isinstance(self.min_delta, list):
            if not all(t >= 0 for t in self.min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(self.min_delta) == 0:
                if self.verbose:
                    _log_info('Disabling min_delta for early stopping.')
                deltas = [0.0] * n_datasets * n_metrics
            elif len(self.min_delta) == 1:
                if self.verbose:
                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
                deltas = self.min_delta * n_datasets * n_metrics
            else:
                if len(self.min_delta) != n_metrics:
                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
                if self.first_metric_only and self.verbose:
                    _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
                deltas = self.min_delta * n_datasets
        else:
            if self.min_delta < 0:
                raise ValueError('Early stopping min_delta must be non-negative.')
            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
                _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
            deltas = [self.min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            self.best_iter.append(0)
            if eval_ret[3]:  # greater is better
                self.best_score.append(float('-inf'))
                self.cmp_op.append(partial(self._gt_delta, delta=delta))
            else:
                self.best_score.append(float('inf'))
                self.cmp_op.append(partial(self._lt_delta, delta=delta))

    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        if env.iteration == env.end_iteration - 1:
            if self.verbose:
                best_score_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
                if self.first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        if not self.enabled:
            return
        # self.best_score_list is initialized to an empty list
        first_time_updating_best_score_list = (self.best_score_list == [])
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]):
                self.best_score[i] = score
                self.best_iter[i] = env.iteration
                if first_time_updating_best_score_list:
                    self.best_score_list.append(env.evaluation_result_list)
                else:
                    self.best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                if self.verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                    if self.first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
            self._final_iteration_check(env, eval_name_splitted, i)


def early_stopping(stopping_rounds: int,
                   first_metric_only: bool = False,
                   verbose: bool = True,
                   min_delta: Union[float, List[float]] = 0.0) -> _EarlyStoppingCallback:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation dataset and one metric.
    If there's more than one, all of them will be checked, but the training data is always ignored.
    To check only the first metric, set ``first_metric_only`` to True.
    The index of the iteration with the best performance will be saved in the ``best_iteration`` attribute of the model.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without improvement after which training will be stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

    Returns
    -------
    callback : _EarlyStoppingCallback
        The callback that activates early stopping.
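
    Examples
    --------
    A minimal usage sketch (illustrative only; the random data and parameter values
    below are assumptions, not part of this module):

    .. code-block:: python

        import numpy as np
        import lightgbm as lgb

        X, y = np.random.rand(200, 2), np.random.rand(200)
        train_set = lgb.Dataset(X[:150], label=y[:150])
        valid_set = lgb.Dataset(X[150:], label=y[150:], reference=train_set)
        bst = lgb.train(
            {"objective": "regression"},
            train_set,
            num_boost_round=100,
            valid_sets=[valid_set],
            callbacks=[lgb.early_stopping(stopping_rounds=5)],
        )
        print(bst.best_iteration)  # best iteration chosen by the early stopping callback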
    """
    return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta)