callback.py 17.6 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Callbacks library."""
3
4
from collections import OrderedDict
from dataclasses import dataclass
5
from functools import partial
6
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
wxchan's avatar
wxchan committed
7

8
9
from .basic import (Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType,
                    _LGBM_BoosterEvalMethodResultWithStandardDeviationType, _log_info, _log_warning)
10
11
12

if TYPE_CHECKING:
    from .engine import CVBooster
wxchan's avatar
wxchan committed
13

14
# Public API of this module; every underscore-prefixed name below is internal.
__all__ = [
    'EarlyStopException',
    'early_stopping',
    'log_evaluation',
    'record_evaluation',
    'reset_parameter',
]

22
# History filled in by ``record_evaluation()``: {dataset name: {metric name: [one value per iteration]}}.
_EvalResultDict = Dict[str, Dict[str, List[Any]]]

# A single evaluation entry: the 4-field form, or the 5-field form that also carries
# a standard deviation at index 4 (the form the callbacks below treat as coming from ``cv()``).
_EvalResultTuple = Union[
    _LGBM_BoosterEvalMethodResultType,
    _LGBM_BoosterEvalMethodResultWithStandardDeviationType,
]

# A full evaluation report for one iteration: a list made of only one of the two forms above.
_ListOfEvalResultTuples = Union[
    List[_LGBM_BoosterEvalMethodResultType],
    List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType],
]

wxchan's avatar
wxchan committed
32

wxchan's avatar
wxchan committed
33
class EarlyStopException(Exception):
    """Exception of early stopping.

    Raise this from a callback passed in via keyword argument ``callbacks``
    in ``cv()`` or ``train()`` to trigger early stopping.
    """

    def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
            The index is 0-based, so pass ``best_iteration=2`` to indicate that the third iteration was the best one.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        # The base initializer is called with no arguments, so ``self.args`` ends up empty;
        # the payload is carried only by the two attributes below.
        super().__init__()
        self.best_iteration, self.best_score = best_iteration, best_score
wxchan's avatar
wxchan committed
54

wxchan's avatar
wxchan committed
55

wxchan's avatar
wxchan committed
56
@dataclass
class CallbackEnv:
    """Callback environment used by callbacks.

    A snapshot of the training state that is passed to each callback in this module.
    """

    # The model being trained: a ``Booster``, or a ``CVBooster`` (presumably during ``cv()``).
    model: Union[Booster, "CVBooster"]
    # Current parameters; ``reset_parameter()`` updates this dict in place.
    params: Dict[str, Any]
    # Index of the current iteration.
    iteration: int
    # First iteration of this run; callbacks do their one-time setup when ``iteration`` equals it.
    begin_iteration: int
    # Exclusive upper bound: the last iteration is ``end_iteration - 1``.
    end_iteration: int
    # Evaluation entries for this iteration, or ``None``.
    evaluation_result_list: Optional[_ListOfEvalResultTuples]
wxchan's avatar
wxchan committed
65

wxchan's avatar
wxchan committed
66

67
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
    """Format metric string.

    Renders ``"<dataset>'s <metric>: <score>"``. For 5-field entries with ``show_stdv``
    set, ``" + <stdv>"`` is appended. Raises ``ValueError`` for any other tuple length.
    """
    field_count = len(value)
    if field_count not in (4, 5):
        raise ValueError("Wrong metric value")
    text = f"{value[0]}'s {value[1]}: {value[2]:g}"
    if show_stdv and field_count == 5:
        text = f"{text} + {value[4]:g}"  # type: ignore[misc]
    return text
wxchan's avatar
wxchan committed
78
79


80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class _LogEvaluationCallback:
    """Internal log evaluation callable class.

    Every ``period`` iterations, logs one tab-separated line containing the 1-based
    iteration number followed by each formatted evaluation entry.
    """

    def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
        # Scheduling hints read by whoever runs the callbacks (not visible in this module).
        self.order = 10
        self.before_iteration = False

        self.period = period
        self.show_stdv = show_stdv

    def __call__(self, env: CallbackEnv) -> None:
        # A non-positive period disables logging entirely.
        if self.period <= 0:
            return
        entries = env.evaluation_result_list
        if not entries:
            return
        round_number = env.iteration + 1
        if round_number % self.period:
            return
        line = '\t'.join(_format_eval_result(entry, self.show_stdv) for entry in entries)
        _log_info(f'[{round_number}]\t{line}')


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
    """Create a callback that logs the evaluation results.

    By default, standard output resource is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation data.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        A value of zero or less disables logging.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : _LogEvaluationCallback
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    callback = _LogEvaluationCallback(period=period, show_stdv=show_stdv)
    return callback
wxchan's avatar
wxchan committed
120
121


122
123
124
class _RecordEvaluationCallback:
    """Internal record evaluation callable class.

    Appends each iteration's evaluation results to the user-supplied ``eval_result``
    dict, keyed by dataset name and then by metric name. 5-field entries (the ``cv()``
    form, which carries a standard deviation) are stored as ``'<metric>-mean'`` and
    ``'<metric>-stdv'`` lists.
    """

    def __init__(self, eval_result: _EvalResultDict) -> None:
        self.order = 20
        self.before_iteration = False

        if not isinstance(eval_result, dict):
            raise TypeError('eval_result should be a dictionary')
        self.eval_result = eval_result

    @staticmethod
    def _require_results(env: CallbackEnv) -> _ListOfEvalResultTuples:
        """Return ``env.evaluation_result_list``, raising ``RuntimeError`` if it is ``None``.

        Iterating ``None`` would otherwise fail with an opaque ``TypeError``.
        """
        if env.evaluation_result_list is None:
            raise RuntimeError("record_evaluation() callback enabled but no evaluation results found.")
        return env.evaluation_result_list

    def _init(self, env: CallbackEnv) -> None:
        """Reset ``eval_result`` and create an empty list for every dataset/metric pair."""
        # Validate before clearing so a bad call does not wipe the caller's dict.
        results = self._require_results(env)
        self.eval_result.clear()
        for item in results:
            if len(item) == 4:  # regular train
                data_name, eval_name = item[:2]
            else:  # cv
                # cv names look like "<dataset> <metric>". Split only once so that
                # metric names containing spaces stay intact.
                data_name, eval_name = item[1].split(maxsplit=1)
            metrics = self.eval_result.setdefault(data_name, OrderedDict())
            if len(item) == 4:
                metrics.setdefault(eval_name, [])
            else:
                metrics.setdefault(f'{eval_name}-mean', [])
                metrics.setdefault(f'{eval_name}-stdv', [])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        for item in self._require_results(env):
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                self.eval_result[data_name][eval_name].append(result)
            else:
                data_name, eval_name = item[1].split(maxsplit=1)
                res_mean = item[2]
                res_stdv = item[4]
                self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
                self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)


162
def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.
        A ``TypeError`` is raised if it is not a ``dict``.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss'
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

        Evaluation entries that carry a standard deviation (as produced for ``cv()``)
        are stored as two lists per metric, ``'<metric>-mean'`` and ``'<metric>-stdv'``.

    Returns
    -------
    callback : _RecordEvaluationCallback
        The callback that records the evaluation history into the passed dictionary.
    """
    recorder = _RecordEvaluationCallback(eval_result=eval_result)
    return recorder
wxchan's avatar
wxchan committed
196
197


198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
class _ResetParameterCallback:
    """Internal reset parameter callable class.

    Before each iteration, computes a value for every configured parameter and
    pushes the ones that differ from ``env.params`` to the model.
    """

    def __init__(self, **kwargs: Union[list, Callable]) -> None:
        self.order = 10
        self.before_iteration = True

        self.kwargs = kwargs

    def __call__(self, env: CallbackEnv) -> None:
        changed = {}
        for name, schedule in self.kwargs.items():
            # Position of the current iteration relative to the start of this run.
            step = env.iteration - env.begin_iteration
            if isinstance(schedule, list):
                if len(schedule) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {name!r} has to be equal to 'num_boost_round'.")
                candidate = schedule[step]
            elif callable(schedule):
                candidate = schedule(step)
            else:
                raise ValueError("Only list and callable values are supported "
                                 "as a mapping from boosting round index to new parameter value.")
            # Skip values that are already in effect to avoid needless resets.
            if candidate != env.params.get(name, None):
                changed[name] = candidate
        if not changed:
            return
        env.model.reset_parameter(changed)
        env.params.update(changed)


226
def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take in-effect on first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).
        ``current_round`` counts from 0 at the first iteration of the current run.
        A list must have exactly one entry per boosting round, otherwise ``ValueError`` is raised.
        Any other value type also raises ``ValueError``.

    Returns
    -------
    callback : _ResetParameterCallback
        The callback that resets the parameter after the first iteration.
    """
    resetter = _ResetParameterCallback(**kwargs)
    return resetter
wxchan's avatar
wxchan committed
248
249


250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
class _EarlyStoppingCallback:
    """Internal early stopping callable class.

    For every entry of ``env.evaluation_result_list``, tracks the best score so far
    and the iteration where it occurred. Raises ``EarlyStopException`` when either:

    - a non-training metric has not improved (by more than its ``min_delta``) for
      ``stopping_rounds`` iterations, or
    - the final iteration is reached.
    """

    def __init__(
        self,
        stopping_rounds: int,
        first_metric_only: bool = False,
        verbose: bool = True,
        min_delta: Union[float, List[float]] = 0.0
    ) -> None:
        self.order = 30
        self.before_iteration = False

        self.stopping_rounds = stopping_rounds
        self.first_metric_only = first_metric_only
        self.verbose = verbose
        self.min_delta = min_delta

        self.enabled = True
        self._reset_storages()

    def _reset_storages(self) -> None:
        # Each list below holds one slot per entry of env.evaluation_result_list, in the same order.
        self.best_score: List[float] = []
        self.best_iter: List[int] = []
        # The full evaluation list captured at each entry's best iteration.
        self.best_score_list: List[_ListOfEvalResultTuples] = []
        # cmp_op[i](current, best) is True when ``current`` counts as an improvement.
        self.cmp_op: List[Callable[[float, float], bool]] = []
        self.first_metric = ''

    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score > best_score + delta

    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score < best_score - delta

    def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
        # An entry counts as training data when it is the "cv_agg" aggregate named "train ..."
        # or when its dataset name equals the model's training-data name.
        return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name

    def _init(self, env: CallbackEnv) -> None:
        """Validate settings and set up per-entry tracking; runs on the first iteration."""
        # evaluation_result_list is Optional. Normalise None to [] so the explicit
        # "at least one dataset" ValueError below fires instead of a TypeError from len(None).
        eval_results = env.evaluation_result_list or []

        is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
        only_train_set = (
            len(eval_results) == 1
            and self._is_train_set(
                ds_name=eval_results[0][0],
                eval_name=eval_results[0][1].split(" ")[0],
                train_name=env.model._train_data_name)
        )
        self.enabled = not is_dart and not only_train_set
        if not self.enabled:
            if is_dart:
                _log_warning('Early stopping is not available in dart mode')
            elif only_train_set:
                _log_warning('Only training set found, disabling early stopping.')
            return
        if not eval_results:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if self.stopping_rounds <= 0:
            raise ValueError("stopping_rounds should be greater than zero.")

        if self.verbose:
            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

        self._reset_storages()

        # The delta expansion below assumes entries are grouped per dataset, with the
        # metrics in the same order inside every group.
        n_metrics = len({m[1] for m in eval_results})
        n_datasets = len(eval_results) // n_metrics
        if isinstance(self.min_delta, list):
            if not all(t >= 0 for t in self.min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(self.min_delta) == 0:
                if self.verbose:
                    _log_info('Disabling min_delta for early stopping.')
                deltas = [0.0] * n_datasets * n_metrics
            elif len(self.min_delta) == 1:
                if self.verbose:
                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
                deltas = self.min_delta * n_datasets * n_metrics
            else:
                if len(self.min_delta) != n_metrics:
                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
                if self.first_metric_only and self.verbose:
                    _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
                deltas = self.min_delta * n_datasets
        else:
            if self.min_delta < 0:
                raise ValueError('Early stopping min_delta must be non-negative.')
            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
                _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
            deltas = [self.min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        self.first_metric = eval_results[0][1].split(" ")[-1]
        for eval_ret, delta in zip(eval_results, deltas):
            self.best_iter.append(0)
            if eval_ret[3]:  # greater is better
                self.best_score.append(float('-inf'))
                self.cmp_op.append(partial(self._gt_delta, delta=delta))
            else:
                self.best_score.append(float('inf'))
                self.cmp_op.append(partial(self._lt_delta, delta=delta))

    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        """On the last iteration (``end_iteration - 1``), report entry ``i``'s best result and stop."""
        if env.iteration == env.end_iteration - 1:
            if self.verbose:
                best_score_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
                if self.first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        if not self.enabled:
            return
        # best_score_list starts empty; the first tracked iteration fills one slot per entry.
        first_time_updating_best_score_list = not self.best_score_list
        for i, eval_ret in enumerate(env.evaluation_result_list):
            score = eval_ret[2]
            if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]):
                self.best_score[i] = score
                self.best_iter[i] = env.iteration
                if first_time_updating_best_score_list:
                    self.best_score_list.append(env.evaluation_result_list)
                else:
                    self.best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = eval_ret[1].split(" ")
            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if self._is_train_set(eval_ret[0], eval_name_splitted[0], env.model._train_data_name):
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                if self.verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                    if self.first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
            self._final_iteration_check(env, eval_name_splitted, i)


def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> _EarlyStoppingCallback:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them. But the training data is ignored anyway.
    To check only the first metric set ``first_metric_only`` to True.
    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.

    Parameters
    ----------
    stopping_rounds : int
        The possible number of rounds without the trend occurrence.
        Must be greater than zero.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.
        A list may instead hold a single value (applied to all metrics), or be empty (disables ``min_delta``).
        All values must be non-negative.

        .. versionadded:: 4.0.0

    Returns
    -------
    callback : _EarlyStoppingCallback
        The callback that activates early stopping.
    """
    callback = _EarlyStoppingCallback(
        stopping_rounds=stopping_rounds,
        first_metric_only=first_metric_only,
        verbose=verbose,
        min_delta=min_delta,
    )
    return callback