# coding: utf-8
"""Callbacks library."""
from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from .basic import (Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType,
                    _LGBM_BoosterEvalMethodResultWithStandardDeviationType, _log_info, _log_warning)

if TYPE_CHECKING:
    from .engine import CVBooster

__all__ = [
    'EarlyStopException',
    'early_stopping',
    'log_evaluation',
    'record_evaluation',
    'reset_parameter',
]

# Evaluation history filled in by ``record_evaluation()``:
# {dataset_name: {metric_name: [value at each recorded iteration, ...]}}
_EvalResultDict = Dict[str, Dict[str, List[Any]]]
# One evaluation result: (dataset_name, metric_name, value, is_higher_better),
# optionally followed by a standard deviation (the 5-element form used for cv results).
_EvalResultTuple = Union[
    _LGBM_BoosterEvalMethodResultType,
    _LGBM_BoosterEvalMethodResultWithStandardDeviationType
]
# All results of one iteration; every tuple in a given list has the same form.
_ListOfEvalResultTuples = Union[
    List[_LGBM_BoosterEvalMethodResultType],
    List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType]
]


class EarlyStopException(Exception):
    """Exception of early stopping.

    Raise this from a callback passed in via keyword argument ``callbacks``
    in ``cv()`` or ``train()`` to trigger early stopping.
    """

    def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
            0-based, so pass ``best_iteration=2`` to indicate that the third iteration was the best one.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        # Forward the constructor arguments so ``self.args`` matches this signature.
        # Unpickling an exception calls ``type(exc)(*exc.args)``; with empty args that
        # call would fail with a TypeError because both parameters are required.
        super().__init__(best_iteration, best_score)
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
@dataclass
class CallbackEnv:
    """Training state handed to each callback.

    The callbacks in this module read these fields; the reset-parameter callback
    additionally calls ``model.reset_parameter()`` and updates ``params`` in place.
    """

    # The model being trained: a ``Booster`` or a ``CVBooster`` (the latter presumably from ``cv()``).
    model: Union[Booster, "CVBooster"]
    # Parameters currently in effect (e.g. inspected for ``boosting='dart'`` by early stopping).
    params: Dict[str, Any]
    # Current iteration index; callbacks treat ``begin_iteration`` as the first call and
    # ``end_iteration - 1`` as the last one. Logged to users as ``iteration + 1``.
    iteration: int
    begin_iteration: int
    end_iteration: int
    # Evaluation results for the current iteration, one tuple per (dataset, metric) pair.
    # The callbacks here raise if this is None when they need it.
    evaluation_result_list: Optional[_ListOfEvalResultTuples]


def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
    """Format metric string.

    Renders one evaluation tuple as ``"<dataset>'s <metric>: <value>"``. A 5-element
    tuple also carries a standard deviation at index 4, which is appended as
    ``" + <stdv>"`` when ``show_stdv`` is True.

    Raises
    ------
    ValueError
        If ``value`` has neither 4 nor 5 elements.
    """
    size = len(value)
    if size not in (4, 5):
        raise ValueError("Wrong metric value")
    text = f"{value[0]}'s {value[1]}: {value[2]:g}"
    if size == 5 and show_stdv:
        text += f" + {value[4]:g}"  # type: ignore[misc]
    return text


class _LogEvaluationCallback:
    """Internal log evaluation callable class.

    Logs the formatted evaluation results every ``period`` iterations.
    """

    def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
        # Scheduling attributes, presumably read by the training loop (outside this file).
        self.order = 10
        self.before_iteration = False

        self.period = period
        self.show_stdv = show_stdv

    def __call__(self, env: CallbackEnv) -> None:
        # Nothing to do when logging is disabled or there are no results.
        if self.period <= 0 or not env.evaluation_result_list:
            return
        round_number = env.iteration + 1
        if round_number % self.period != 0:
            return
        pieces = (_format_eval_result(entry, self.show_stdv) for entry in env.evaluation_result_list)
        _log_info(f'[{round_number}]\t' + '\t'.join(pieces))


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
    """Create a callback that logs the evaluation results.

    By default, standard output resource is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation data.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
        A value of 0 or less disables this callback's logging.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : _LogEvaluationCallback
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    callback = _LogEvaluationCallback(period=period, show_stdv=show_stdv)
    return callback


class _RecordEvaluationCallback:
    """Internal record evaluation callable class.

    Appends every evaluation result to the user-supplied ``eval_result`` dict,
    keyed by dataset name and then by metric name. Results that carry a standard
    deviation (5-element tuples, as in cv) are stored under two keys:
    ``'<metric>-mean'`` and ``'<metric>-stdv'``.
    """

    def __init__(self, eval_result: _EvalResultDict) -> None:
        self.order = 20
        self.before_iteration = False

        if not isinstance(eval_result, dict):
            raise TypeError('eval_result should be a dictionary')
        self.eval_result = eval_result

    @staticmethod
    def _get_evaluation_result_list(env: CallbackEnv) -> _ListOfEvalResultTuples:
        """Return ``env.evaluation_result_list``, raising RuntimeError if it is None."""
        if env.evaluation_result_list is None:
            raise RuntimeError(
                "record_evaluation() callback enabled but no evaluation results found. This is a probably bug in LightGBM. "
                "Please report it at https://github.com/microsoft/LightGBM/issues"
            )
        return env.evaluation_result_list

    def _init(self, env: CallbackEnv) -> None:
        """Empty ``eval_result`` and create one empty history list per (dataset, metric) pair."""
        self.eval_result.clear()
        for item in self._get_evaluation_result_list(env):
            if len(item) == 4:  # regular train
                data_name, eval_name = item[:2]
            else:  # cv
                # item[1] is "<dataset> <metric>"; split only once so that a metric
                # name containing whitespace does not break the two-name unpacking.
                data_name, eval_name = item[1].split(maxsplit=1)
            metrics = self.eval_result.setdefault(data_name, OrderedDict())
            if len(item) == 4:
                metrics.setdefault(eval_name, [])
            else:
                metrics.setdefault(f'{eval_name}-mean', [])
                metrics.setdefault(f'{eval_name}-stdv', [])

    def __call__(self, env: CallbackEnv) -> None:
        # The first iteration of a run (re)initialises the history dict.
        if env.iteration == env.begin_iteration:
            self._init(env)
        for item in self._get_evaluation_result_list(env):
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                self.eval_result[data_name][eval_name].append(result)
            else:
                data_name, eval_name = item[1].split(maxsplit=1)
                res_mean = item[2]
                res_stdv = item[4]
                self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
                self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)


def record_evaluation(eval_result: _EvalResultDict) -> _RecordEvaluationCallback:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss'
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

        When the evaluation results carry a standard deviation (as with ``cv()``),
        each metric is recorded under two keys instead: ``'<metric>-mean'`` and ``'<metric>-stdv'``.

    Returns
    -------
    callback : _RecordEvaluationCallback
        The callback that records the evaluation history into the passed dictionary.

    Raises
    ------
    TypeError
        If ``eval_result`` is not a dictionary.
    """
    return _RecordEvaluationCallback(eval_result=eval_result)


class _ResetParameterCallback:
    """Internal reset parameter callable class.

    Before each iteration, looks up the scheduled value of every configured
    parameter and pushes the ones that differ from ``env.params`` to the model.
    """

    def __init__(self, **kwargs: Union[list, Callable]) -> None:
        self.order = 10
        self.before_iteration = True

        self.kwargs = kwargs

    def __call__(self, env: CallbackEnv) -> None:
        # Position of this iteration within the current run, and the run length.
        round_idx = env.iteration - env.begin_iteration
        n_rounds = env.end_iteration - env.begin_iteration
        changed: Dict[str, Any] = {}
        for param_name, schedule in self.kwargs.items():
            if isinstance(schedule, list):
                if len(schedule) != n_rounds:
                    raise ValueError(f"Length of list {param_name!r} has to be equal to 'num_boost_round'.")
                value = schedule[round_idx]
            elif callable(schedule):
                value = schedule(round_idx)
            else:
                raise ValueError("Only list and callable values are supported "
                                 "as a mapping from boosting round index to new parameter value.")
            # Only push values that actually changed.
            if value != env.params.get(param_name):
                changed[param_name] = value
        if changed:
            env.model.reset_parameter(changed)
            env.params.update(changed)


def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).
        ``current_round`` counts from 0 at the first round of the training run,
        and a list must hold exactly one value per boosting round.

    Returns
    -------
    callback : _ResetParameterCallback
        The callback that resets the parameter after the first iteration.
    """
    callback = _ResetParameterCallback(**kwargs)
    return callback


class _EarlyStoppingCallback:
    """Internal early stopping callable class.

    For every (dataset, metric) pair in ``env.evaluation_result_list`` it tracks the
    best score so far and the iteration it was reached at. It raises
    ``EarlyStopException`` once a tracked pair has not improved by more than its
    ``min_delta`` for ``stopping_rounds`` iterations, or at the final iteration.
    Results computed on the training data are never used to stop.
    """

    def __init__(
        self,
        stopping_rounds: int,
        first_metric_only: bool = False,
        verbose: bool = True,
        min_delta: Union[float, List[float]] = 0.0
    ) -> None:
        self.order = 30
        self.before_iteration = False

        self.stopping_rounds = stopping_rounds
        self.first_metric_only = first_metric_only
        self.verbose = verbose
        self.min_delta = min_delta

        self.enabled = True
        self._reset_storages()

    def _reset_storages(self) -> None:
        # Each list holds one entry per element of env.evaluation_result_list, in the same order.
        self.best_score: List[float] = []
        self.best_iter: List[int] = []
        self.best_score_list: List[_ListOfEvalResultTuples] = []
        self.cmp_op: List[Callable[[float, float], bool]] = []
        self.first_metric = ''

    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score > best_score + delta

    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score < best_score - delta

    def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
        return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name

    def _init(self, env: CallbackEnv) -> None:
        """Validate settings and reset per-metric state at the start of a training run."""
        if not env.evaluation_result_list:
            raise ValueError(
                "For early stopping, at least one dataset and eval metric is required for evaluation"
            )

        is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
        train_name = env.model._train_data_name
        # Results on the training data are skipped in __call__, so if *every* result comes
        # from the training data (possibly with several metrics), early stopping can never
        # trigger: disable it with a warning instead of silently running to the end.
        only_train_set = all(
            self._is_train_set(
                ds_name=item[0],
                eval_name=item[1].split(" ")[0],
                train_name=train_name
            )
            for item in env.evaluation_result_list
        )
        self.enabled = not is_dart and not only_train_set
        if not self.enabled:
            if is_dart:
                _log_warning('Early stopping is not available in dart mode')
            elif only_train_set:
                _log_warning('Only training set found, disabling early stopping.')
            return

        if self.stopping_rounds <= 0:
            raise ValueError("stopping_rounds should be greater than zero.")

        if self.verbose:
            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

        self._reset_storages()

        n_metrics = len({m[1] for m in env.evaluation_result_list})
        n_datasets = len(env.evaluation_result_list) // n_metrics
        # Build one delta per result. A per-metric list is tiled once per dataset, which
        # assumes results are ordered dataset by dataset with metrics in a fixed order.
        if isinstance(self.min_delta, list):
            if not all(t >= 0 for t in self.min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(self.min_delta) == 0:
                if self.verbose:
                    _log_info('Disabling min_delta for early stopping.')
                deltas = [0.0] * n_datasets * n_metrics
            elif len(self.min_delta) == 1:
                if self.verbose:
                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
                deltas = self.min_delta * n_datasets * n_metrics
            else:
                if len(self.min_delta) != n_metrics:
                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
                if self.first_metric_only and self.verbose:
                    _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
                deltas = self.min_delta * n_datasets
        else:
            if self.min_delta < 0:
                raise ValueError('Early stopping min_delta must be non-negative.')
            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
                _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
            deltas = [self.min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            self.best_iter.append(0)
            if eval_ret[3]:  # greater is better
                self.best_score.append(float('-inf'))
                self.cmp_op.append(partial(self._gt_delta, delta=delta))
            else:
                self.best_score.append(float('inf'))
                self.cmp_op.append(partial(self._lt_delta, delta=delta))

    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        """Stop at the last iteration, reporting the best iteration for result ``i``."""
        if env.iteration == env.end_iteration - 1:
            if self.verbose:
                best_score_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
                if self.first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        if not self.enabled:
            return
        if env.evaluation_result_list is None:
            raise RuntimeError(
                "early_stopping() callback enabled but no evaluation results found. This is a probably bug in LightGBM. "
                "Please report it at https://github.com/microsoft/LightGBM/issues"
            )
        # self.best_score_list is initialized to an empty list
        first_time_updating_best_score_list = not self.best_score_list
        for i, eval_ret in enumerate(env.evaluation_result_list):
            score = eval_ret[2]
            if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]):
                self.best_score[i] = score
                self.best_iter[i] = env.iteration
                if first_time_updating_best_score_list:
                    self.best_score_list.append(env.evaluation_result_list)
                else:
                    self.best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = eval_ret[1].split(" ")
            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if self._is_train_set(eval_ret[0], eval_name_splitted[0], env.model._train_data_name):
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                if self.verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                    if self.first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
            self._final_iteration_check(env, eval_name_splitted, i)


def early_stopping(
    stopping_rounds: int,
    first_metric_only: bool = False,
    verbose: bool = True,
    min_delta: Union[float, List[float]] = 0.0
) -> _EarlyStoppingCallback:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by more than ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them. But the training data is ignored anyway.
    To check only the first metric set ``first_metric_only`` to True.
    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.
    Early stopping is disabled, with a warning, in dart mode or when only the training data is evaluated.

    Parameters
    ----------
    stopping_rounds : int
        The possible number of rounds without the trend occurrence.
        Must be greater than zero; otherwise a ``ValueError`` is raised when training starts.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.
        A list with a single value applies that value to all metrics, and an empty list disables ``min_delta``.
        All values must be non-negative.

        .. versionadded:: 4.0.0

    Returns
    -------
    callback : _EarlyStoppingCallback
        The callback that activates early stopping.
    """
    return _EarlyStoppingCallback(
        stopping_rounds=stopping_rounds,
        first_metric_only=first_metric_only,
        verbose=verbose,
        min_delta=min_delta
    )