callback.py 17.3 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Callbacks library."""
3
4
from collections import OrderedDict
from dataclasses import dataclass
5
from functools import partial
6
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
wxchan's avatar
wxchan committed
7

8
9
10
11
from .basic import Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType, _log_info, _log_warning

if TYPE_CHECKING:
    from .engine import CVBooster
wxchan's avatar
wxchan committed
12

13
14
15
16
17
18
19
__all__ = [
    'early_stopping',
    'log_evaluation',
    'record_evaluation',
    'reset_parameter',
]

# Storage layout filled by record_evaluation(): {data_name: {eval_name: [one value per iteration]}}.
_EvalResultDict = Dict[str, Dict[str, List[Any]]]
# One evaluation result: either a Booster eval-method result (4 fields) or a 5-field tuple
# (data_name, eval_name, result, is_higher_better, stdv) that also carries a standard deviation.
_EvalResultTuple = Union[_LGBM_BoosterEvalMethodResultType, Tuple[str, str, float, bool, float]]
# A list of evaluation results where every element has the same one of the two shapes above.
_ListOfEvalResultTuples = Union[
    List[_LGBM_BoosterEvalMethodResultType],
    List[Tuple[str, str, float, bool, float]],
]

wxchan's avatar
wxchan committed
30

wxchan's avatar
wxchan committed
31
class EarlyStopException(Exception):
    """Exception of early stopping.

    Raised by the early stopping callback to end training. It carries the best
    iteration that was found and the evaluation results recorded at that iteration.
    """

    def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        super().__init__()
        # Both values are exposed unchanged as public attributes for the code that catches the exception.
        self.best_iteration, self.best_score = best_iteration, best_score
wxchan's avatar
wxchan committed
47

wxchan's avatar
wxchan committed
48

wxchan's avatar
wxchan committed
49
# Callback environment used by callbacks
50
51
52
53
54
55
56
57
@dataclass
class CallbackEnv:
    """State handed to each training callback on every boosting iteration.

    The callbacks in this module read these fields to decide what to do. Some of
    them also act on the state: they call methods on ``model`` and update ``params``
    in place.
    """

    # Model being trained: a Booster, or a CVBooster (the latter imported only for type checking).
    model: Union[Booster, "CVBooster"]
    # Training parameters; callbacks read them (e.g. the boosting type) and may update them in place.
    params: Dict[str, Any]
    # Index of the current iteration; callbacks use ``iteration - begin_iteration`` as the round offset
    # and log ``iteration + 1`` as the human-facing round number.
    iteration: int
    # First iteration index of this run; callbacks (re)initialise their state when ``iteration == begin_iteration``.
    begin_iteration: int
    # One past the last iteration index: the final round has ``iteration == end_iteration - 1``.
    end_iteration: int
    # Evaluation results for the current iteration, or None.
    # NOTE(review): presumably None for callbacks that run before the iteration (``before_iteration=True``) — confirm in the training loop.
    evaluation_result_list: Optional[List[_LGBM_BoosterEvalMethodResultType]]
wxchan's avatar
wxchan committed
58

wxchan's avatar
wxchan committed
59

60
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
    """Format metric string.

    Renders ``(data_name, eval_name, result, is_higher_better[, stdv])`` as
    ``"<data_name>'s <eval_name>: <result>"``. For 5-element tuples the standard
    deviation is appended as ``" + <stdv>"`` only when ``show_stdv`` is true.
    Numbers use the ``g`` format spec.

    Raises
    ------
    ValueError
        If ``value`` has neither 4 nor 5 elements.
    """
    field_count = len(value)
    if field_count not in (4, 5):
        raise ValueError("Wrong metric value")

    text = f"{value[0]}'s {value[1]}: {value[2]:g}"
    if field_count == 5 and show_stdv:
        text = f"{text} + {value[4]:g}"  # type: ignore[misc]
    return text
wxchan's avatar
wxchan committed
71
72


73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
class _LogEvaluationCallback:
    """Internal log evaluation callable class.

    Runs after each iteration. On every ``period``-th round (1-based) it logs all
    evaluation results of that round on one tab-separated line.
    """

    def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
        self.order = 10
        self.before_iteration = False

        self.period = period
        self.show_stdv = show_stdv

    def __call__(self, env: CallbackEnv) -> None:
        # A non-positive period disables logging entirely.
        if self.period <= 0:
            return
        entries = env.evaluation_result_list
        if not entries:
            return
        round_number = env.iteration + 1
        if round_number % self.period != 0:
            return
        result = '\t'.join([_format_eval_result(entry, self.show_stdv) for entry in entries])
        _log_info(f'[{round_number}]\t{result}')


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
    """Create a callback that logs the evaluation results.

    By default, standard output resource is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation data.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
        A value of 0 or less turns this callback's logging off.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : _LogEvaluationCallback
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    callback = _LogEvaluationCallback(period=period, show_stdv=show_stdv)
    return callback
wxchan's avatar
wxchan committed
113
114


115
116
117
class _RecordEvaluationCallback:
    """Internal record evaluation callable class.

    On the first iteration the target dictionary is cleared and pre-populated with
    empty lists. After that, every iteration's results are appended to those lists.

    * 4-element results are stored as ``eval_result[data_name][eval_name]``.
    * 5-element (cv) results, whose name field is ``"<dataset type> <metric>"``, are
      stored as ``eval_result[<dataset type>]['<metric>-mean']`` and
      ``eval_result[<dataset type>]['<metric>-stdv']``.
    """

    def __init__(self, eval_result: _EvalResultDict) -> None:
        self.order = 20
        self.before_iteration = False

        if not isinstance(eval_result, dict):
            raise TypeError('eval_result should be a dictionary')
        self.eval_result = eval_result

    @staticmethod
    def _split_names(item: _EvalResultTuple) -> Tuple[str, str]:
        """Return ``(data_name, eval_name)`` for one evaluation result tuple."""
        if len(item) == 4:  # regular train
            return item[0], item[1]
        # cv: the name field is "<dataset type> <metric>". Split only once so that a
        # metric name which itself contains spaces (e.g. from a custom feval) stays intact
        # instead of failing to unpack.
        data_name, eval_name = item[1].split(maxsplit=1)
        return data_name, eval_name

    def _init(self, env: CallbackEnv) -> None:
        """Clear ``eval_result`` and create an empty list for every reported metric."""
        self.eval_result.clear()
        for item in env.evaluation_result_list:
            data_name, eval_name = self._split_names(item)
            metrics = self.eval_result.setdefault(data_name, OrderedDict())
            if len(item) == 4:
                metrics.setdefault(eval_name, [])
            else:
                metrics.setdefault(f'{eval_name}-mean', [])
                metrics.setdefault(f'{eval_name}-stdv', [])

    def __call__(self, env: CallbackEnv) -> None:
        # The field is Optional; fail with an explicit message instead of
        # "'NoneType' object is not iterable" (and before clearing the user's dict).
        if env.evaluation_result_list is None:
            raise TypeError('record_evaluation() callback received no evaluation results '
                            '(env.evaluation_result_list is None).')
        if env.iteration == env.begin_iteration:
            self._init(env)
        for item in env.evaluation_result_list:
            data_name, eval_name = self._split_names(item)
            metrics = self.eval_result[data_name]
            if len(item) == 4:
                metrics[eval_name].append(item[2])
            else:
                metrics[f'{eval_name}-mean'].append(item[2])  # aggregated mean
                metrics[f'{eval_name}-stdv'].append(item[4])  # aggregated standard deviation


155
def record_evaluation(eval_result: _EvalResultDict) -> _RecordEvaluationCallback:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss'
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

        For cross-validation results (which also carry a standard deviation), each metric is
        stored as two lists, ``'<metric>-mean'`` and ``'<metric>-stdv'``, under the dataset type
        (e.g. ``'train'``).

    Returns
    -------
    callback : _RecordEvaluationCallback
        The callback that records the evaluation history into the passed dictionary.

    Raises
    ------
    TypeError
        If ``eval_result`` is not a dictionary.
    """
    return _RecordEvaluationCallback(eval_result=eval_result)
wxchan's avatar
wxchan committed
189
190


191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
class _ResetParameterCallback:
    """Internal reset parameter callable class.

    Runs before each iteration. For every configured parameter it computes the
    value for the current round and pushes only the values that differ from
    ``env.params`` to the model.
    """

    def __init__(self, **kwargs: Union[list, Callable]) -> None:
        self.order = 10
        self.before_iteration = True

        self.kwargs = kwargs

    def __call__(self, env: CallbackEnv) -> None:
        round_offset = env.iteration - env.begin_iteration
        total_rounds = env.end_iteration - env.begin_iteration
        changed = {}
        for name, schedule in self.kwargs.items():
            if isinstance(schedule, list):
                # A list schedule must provide exactly one value per round.
                if len(schedule) != total_rounds:
                    raise ValueError(f"Length of list {name!r} has to be equal to 'num_boost_round'.")
                candidate = schedule[round_offset]
            elif callable(schedule):
                candidate = schedule(round_offset)
            else:
                raise ValueError("Only list and callable values are supported "
                                 "as a mapping from boosting round index to new parameter value.")
            if candidate != env.params.get(name, None):
                changed[name] = candidate
        if not changed:
            return
        env.model.reset_parameter(changed)
        env.params.update(changed)


219
def reset_parameter(**kwargs: Union[list, Callable]) -> _ResetParameterCallback:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).
        ``current_round`` is 0 at the first round of the current training run.
        A list must hold exactly ``num_boost_round`` values, and only lists and callables are accepted;
        otherwise the returned callback raises ``ValueError`` when it is invoked.

    Returns
    -------
    callback : _ResetParameterCallback
        The callback that resets the parameter after the first iteration.
    """
    return _ResetParameterCallback(**kwargs)
wxchan's avatar
wxchan committed
241
242


243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
class _EarlyStoppingCallback:
    """Internal early stopping callable class.

    For every entry of ``env.evaluation_result_list`` (one per dataset/metric pair,
    in list order), the callback keeps the best score seen so far and the iteration
    at which it occurred. It raises ``EarlyStopException`` in two cases:

    * a non-training entry has not improved by more than its ``min_delta`` for
      ``stopping_rounds`` iterations;
    * the final iteration is reached.
    """

    def __init__(
        self,
        stopping_rounds: int,
        first_metric_only: bool = False,
        verbose: bool = True,
        min_delta: Union[float, List[float]] = 0.0
    ) -> None:
        self.order = 30
        self.before_iteration = False

        self.stopping_rounds = stopping_rounds
        self.first_metric_only = first_metric_only
        self.verbose = verbose
        self.min_delta = min_delta

        self.enabled = True
        self._reset_storages()

    def _reset_storages(self) -> None:
        # Per-entry state, index-aligned with env.evaluation_result_list.
        self.best_score: List[float] = []
        self.best_iter: List[int] = []
        # The whole evaluation_result_list captured at each entry's best iteration (used for logging/exception).
        self.best_score_list: List[_ListOfEvalResultTuples] = []
        # "Is the new score better than the best one?" predicates with the entry's min_delta bound in.
        self.cmp_op: List[Callable[[float, float], bool]] = []
        self.first_metric = ''

    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score > best_score + delta

    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score < best_score - delta

    def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
        # cv() aggregates report training scores as dataset "cv_agg" with name "train <metric>";
        # otherwise an entry is training data when its dataset name equals the model's training-data name.
        return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name

    @staticmethod
    def _metric_name(eval_ret: _EvalResultTuple) -> str:
        """Return the bare metric name of one evaluation result entry.

        Entries of the ``"cv_agg"`` dataset are named ``"<dataset type> <metric>"``
        (e.g. ``"train l1"``). Only that prefix is stripped, so metric names that
        themselves contain spaces (e.g. from a custom ``feval``) are kept intact.
        """
        if eval_ret[0] == "cv_agg":
            return eval_ret[1].split(" ", 1)[-1]
        return eval_ret[1]

    def _get_deltas(self, evaluation_result_list: List[_LGBM_BoosterEvalMethodResultType]) -> List[float]:
        """Validate ``min_delta`` and expand it to one value per evaluation entry.

        Entries are assumed to be grouped by dataset, with the same metrics in the
        same order for each dataset. A per-metric list is therefore tiled once per dataset.

        Raises
        ------
        ValueError
            If any delta is negative, or a list of deltas has neither 0, 1, nor one-per-metric values.
        """
        # Count metrics by bare name. Otherwise cv()'s "train l1" and "valid l1" would count as
        # two metrics, and a correct one-value-per-metric list would be rejected below.
        n_metrics = len({self._metric_name(m) for m in evaluation_result_list})
        n_datasets = len(evaluation_result_list) // n_metrics
        if isinstance(self.min_delta, list):
            if not all(t >= 0 for t in self.min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(self.min_delta) == 0:
                if self.verbose:
                    _log_info('Disabling min_delta for early stopping.')
                return [0.0] * n_datasets * n_metrics
            if len(self.min_delta) == 1:
                if self.verbose:
                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
                return self.min_delta * n_datasets * n_metrics
            if len(self.min_delta) != n_metrics:
                raise ValueError('Must provide a single value for min_delta or as many as metrics.')
            if self.first_metric_only and self.verbose:
                _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
            return self.min_delta * n_datasets
        if self.min_delta < 0:
            raise ValueError('Early stopping min_delta must be non-negative.')
        if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
            _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
        return [self.min_delta] * n_datasets * n_metrics

    def _init(self, env: CallbackEnv) -> None:
        """Decide whether early stopping is enabled, validate settings and reset per-entry state."""
        is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
        # evaluation_result_list is Optional; check it before calling len() / indexing into it.
        only_train_set = (
            env.evaluation_result_list is not None
            and len(env.evaluation_result_list) == 1
            and self._is_train_set(
                ds_name=env.evaluation_result_list[0][0],
                eval_name=env.evaluation_result_list[0][1].split(" ")[0],
                train_name=env.model._train_data_name)
        )
        self.enabled = not is_dart and not only_train_set
        if not self.enabled:
            if is_dart:
                _log_warning('Early stopping is not available in dart mode')
            elif only_train_set:
                _log_warning('Only training set found, disabling early stopping.')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if self.stopping_rounds <= 0:
            raise ValueError("stopping_rounds should be greater than zero.")

        if self.verbose:
            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

        self._reset_storages()

        deltas = self._get_deltas(env.evaluation_result_list)

        self.first_metric = self._metric_name(env.evaluation_result_list[0])
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            self.best_iter.append(0)
            if eval_ret[3]:  # greater is better
                self.best_score.append(float('-inf'))
                self.cmp_op.append(partial(self._gt_delta, delta=delta))
            else:
                self.best_score.append(float('inf'))
                self.cmp_op.append(partial(self._lt_delta, delta=delta))

    def _final_iteration_check(self, env: CallbackEnv, metric_name: str, i: int) -> None:
        """On the last iteration, stop training and report entry ``i``'s best result."""
        if env.iteration == env.end_iteration - 1:
            if self.verbose:
                best_score_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
                if self.first_metric_only:
                    _log_info(f"Evaluated only: {metric_name}")
            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        if not self.enabled:
            return
        # self.best_score_list is initialized to an empty list
        first_time_updating_best_score_list = (self.best_score_list == [])
        for i, eval_ret in enumerate(env.evaluation_result_list):
            score = eval_ret[2]
            if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]):
                self.best_score[i] = score
                self.best_iter[i] = env.iteration
                if first_time_updating_best_score_list:
                    self.best_score_list.append(env.evaluation_result_list)
                else:
                    self.best_score_list[i] = env.evaluation_result_list
            metric_name = self._metric_name(eval_ret)
            if self.first_metric_only and self.first_metric != metric_name:
                continue  # use only the first metric for early stopping
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            if self._is_train_set(eval_ret[0], eval_ret[1].split(" ")[0], env.model._train_data_name):
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                if self.verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                    if self.first_metric_only:
                        _log_info(f"Evaluated only: {metric_name}")
                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
            self._final_iteration_check(env, metric_name, i)


def early_stopping(
    stopping_rounds: int,
    first_metric_only: bool = False,
    verbose: bool = True,
    min_delta: Union[float, List[float]] = 0.0,
) -> _EarlyStoppingCallback:
    """Create a callback that activates early stopping.

    Training continues only while the validation score keeps improving by at least
    ``min_delta``, and it must do so at least once every ``stopping_rounds`` round(s).
    At least one validation dataset and one metric are required.
    When there are several, all of them are checked; training data is always ignored.
    Set ``first_metric_only`` to True to check only the first metric.
    The index of the iteration with the best performance will be saved in the
    ``best_iteration`` attribute of a model.

    Parameters
    ----------
    stopping_rounds : int
        The possible number of rounds without the trend occurrence.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

        .. versionadded:: 4.0.0

    Returns
    -------
    callback : _EarlyStoppingCallback
        The callback that activates early stopping.
    """
    callback = _EarlyStoppingCallback(
        stopping_rounds=stopping_rounds,
        first_metric_only=first_metric_only,
        verbose=verbose,
        min_delta=min_delta,
    )
    return callback