callback.py 19.6 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Callbacks library."""
3
4
from collections import OrderedDict
from dataclasses import dataclass
5
from functools import partial
6
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
wxchan's avatar
wxchan committed
7

8
9
10
11
12
13
14
15
from .basic import (
    Booster,
    _ConfigAliases,
    _LGBM_BoosterEvalMethodResultType,
    _LGBM_BoosterEvalMethodResultWithStandardDeviationType,
    _log_info,
    _log_warning,
)
16
17
18

if TYPE_CHECKING:
    from .engine import CVBooster
wxchan's avatar
wxchan committed
19

20
# Public names exported by ``from <module> import *``.
__all__ = [
    "EarlyStopException",
    "early_stopping",
    "log_evaluation",
    "record_evaluation",
    "reset_parameter",
]

28
# History container filled by ``record_evaluation()``:
# dataset name -> metric name -> list of recorded values.
_EvalResultDict = Dict[str, Dict[str, List[Any]]]
# One evaluation entry: either the 4-field form or the 5-field form that
# additionally carries a standard deviation.
_EvalResultTuple = Union[
    _LGBM_BoosterEvalMethodResultType,
    _LGBM_BoosterEvalMethodResultWithStandardDeviationType,
]
# All evaluation entries produced for one iteration (homogeneous in form).
_ListOfEvalResultTuples = Union[
    List[_LGBM_BoosterEvalMethodResultType],
    List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType],
]

wxchan's avatar
wxchan committed
38

wxchan's avatar
wxchan committed
39
class EarlyStopException(Exception):
    """Exception of early stopping.

    Raise this from a callback passed in via keyword argument ``callbacks``
    in ``cv()`` or ``train()`` to trigger early stopping.
    """

    def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
            0-based... pass ``best_iteration=2`` to indicate that the third iteration was the best one.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        # Forward both values to ``Exception`` so they are stored in ``self.args``.
        # ``copy`` and ``pickle`` rebuild an exception by calling ``cls(*self.args)``;
        # with empty ``args`` that call fails because both parameters are required.
        super().__init__(best_iteration, best_score)
        self.best_iteration = best_iteration
        self.best_score = best_score
wxchan's avatar
wxchan committed
60

wxchan's avatar
wxchan committed
61

wxchan's avatar
wxchan committed
62
# Callback environment used by callbacks
63
64
65
66
67
68
69
@dataclass
class CallbackEnv:
    """State of the training run that is passed to a callback when it is called."""

    # Model being trained: a single ``Booster`` or a ``CVBooster``.
    model: Union[Booster, "CVBooster"]
    # Parameters of the run; the reset-parameter callback updates this dict in place.
    params: Dict[str, Any]
    # Index of the current iteration (callbacks log it as ``iteration + 1``).
    iteration: int
    # Index of the first iteration of this run.
    begin_iteration: int
    # End bound of the run; ``end_iteration - 1`` is treated as the final iteration.
    end_iteration: int
    # Evaluation entries for the current iteration, or ``None`` when there are none.
    evaluation_result_list: Optional[_ListOfEvalResultTuples]
wxchan's avatar
wxchan committed
71

wxchan's avatar
wxchan committed
72

73
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
    """Format metric string.

    Parameters
    ----------
    value : tuple
        Evaluation entry with 4 fields, or with 5 fields where the last one is the stdv.
    show_stdv : bool
        Whether to append the stdv (only used for 5-field entries).

    Returns
    -------
    formatted : str
        ``"<value[0]>'s <value[1]>: <value[2]>"``, optionally followed by ``" + <stdv>"``.

    Raises
    ------
    ValueError
        If ``value`` has neither 4 nor 5 fields.
    """
    n_fields = len(value)
    if n_fields not in (4, 5):
        raise ValueError("Wrong metric value")
    formatted = f"{value[0]}'s {value[1]}: {value[2]:g}"
    if n_fields == 5 and show_stdv:
        formatted += f" + {value[4]:g}"  # type: ignore[misc]
    return formatted
wxchan's avatar
wxchan committed
84
85


86
87
88
89
90
91
92
93
94
95
96
97
class _LogEvaluationCallback:
    """Internal log evaluation callable class."""

    def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
        # NOTE(review): ``order`` and ``before_iteration`` are not read in this class;
        # presumably the training loop uses them to schedule callbacks - confirm there.
        self.order = 10
        self.before_iteration = False

        self.period = period
        self.show_stdv = show_stdv

    def __call__(self, env: CallbackEnv) -> None:
        """Log the evaluation results of ``env`` on every ``period``-th iteration."""
        # A non-positive period or a missing/empty result list turns logging off.
        if self.period <= 0 or not env.evaluation_result_list:
            return
        iteration_number = env.iteration + 1
        if iteration_number % self.period != 0:
            return
        result = "\t".join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list])
        _log_info(f"[{iteration_number}]\t{result}")
100
101
102


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
    """Create a callback that logs the evaluation results.

    By default, standard output resource is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation data.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : _LogEvaluationCallback
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    return _LogEvaluationCallback(period=period, show_stdv=show_stdv)
wxchan's avatar
wxchan committed
126
127


128
129
130
class _RecordEvaluationCallback:
    """Internal record evaluation callable class.

    Appends the evaluation results of every iteration to the dictionary given
    at construction time, keyed first by dataset name and then by metric name.
    """

    def __init__(self, eval_result: _EvalResultDict) -> None:
        # NOTE(review): ``order`` and ``before_iteration`` are not read in this class;
        # presumably the training loop uses them to schedule callbacks - confirm there.
        self.order = 20
        self.before_iteration = False

        if not isinstance(eval_result, dict):
            raise TypeError("eval_result should be a dictionary")
        self.eval_result = eval_result

    @staticmethod
    def _require_results(env: CallbackEnv) -> None:
        """Raise ``RuntimeError`` if ``env`` carries no evaluation result list."""
        if env.evaluation_result_list is None:
            raise RuntimeError(
                "record_evaluation() callback enabled but no evaluation results found. This is a probably bug in LightGBM. "
                "Please report it at https://github.com/microsoft/LightGBM/issues"
            )

    def _init(self, env: CallbackEnv) -> None:
        """Empty the history dict and create one empty list per recorded series."""
        self._require_results(env)
        self.eval_result.clear()
        for item in env.evaluation_result_list:
            if len(item) == 4:  # regular train
                data_name, eval_name = item[:2]
                series_names = [eval_name]
            else:  # cv
                # item[1] is "<dataset type> <metric>"; split only once so that a
                # metric name that itself contains spaces stays in one piece.
                data_name, eval_name = item[1].split(maxsplit=1)
                series_names = [f"{eval_name}-mean", f"{eval_name}-stdv"]
            per_dataset = self.eval_result.setdefault(data_name, OrderedDict())
            for series_name in series_names:
                per_dataset.setdefault(series_name, [])

    def __call__(self, env: CallbackEnv) -> None:
        """Append the results of the current iteration to the history dict."""
        if env.iteration == env.begin_iteration:
            self._init(env)
        self._require_results(env)
        for item in env.evaluation_result_list:
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                self.eval_result[data_name][eval_name].append(result)
            else:
                data_name, eval_name = item[1].split(maxsplit=1)
                per_dataset = self.eval_result[data_name]
                per_dataset[f"{eval_name}-mean"].append(item[2])
                per_dataset[f"{eval_name}-stdv"].append(item[4])  # type: ignore[misc]
176
177


178
def record_evaluation(eval_result: _EvalResultDict) -> _RecordEvaluationCallback:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss'
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : _RecordEvaluationCallback
        The callback that records the evaluation history into the passed dictionary.
    """
    return _RecordEvaluationCallback(eval_result=eval_result)
wxchan's avatar
wxchan committed
212
213


214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
class _ResetParameterCallback:
    """Internal reset parameter callable class.

    Each keyword argument maps a parameter name to either a list of per-round
    values or a callable taking the round index and returning the value.
    """

    def __init__(self, **kwargs: Union[list, Callable]) -> None:
        # NOTE(review): ``order`` and ``before_iteration`` are not read in this class;
        # presumably the training loop uses them to schedule callbacks - confirm there.
        self.order = 10
        self.before_iteration = True

        self.kwargs = kwargs

    def __call__(self, env: CallbackEnv) -> None:
        """Compute this round's parameter values and apply those that changed.

        Raises
        ------
        ValueError
            If a list does not have one entry per round of the run, or if a value
            is neither a list nor a callable.
        """
        # Rounds are counted from the start of this run, not from iteration 0.
        round_index = env.iteration - env.begin_iteration
        new_parameters = {}
        for key, value in self.kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.")
                new_param = value[round_index]
            elif callable(value):
                new_param = value(round_index)
            else:
                raise ValueError(
                    "Only list and callable values are supported "
                    "as a mapping from boosting round index to new parameter value."
                )
            # Only parameters whose value differs from the current one are reset.
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if not new_parameters:
            return
        if isinstance(env.model, Booster):
            env.model.reset_parameter(new_parameters)
        else:
            # CVBooster holds a list of Booster objects, each needs to be updated
            for booster in env.model.boosters:
                booster.reset_parameter(new_parameters)
        env.params.update(new_parameters)


249
def reset_parameter(**kwargs: Union[list, Callable]) -> _ResetParameterCallback:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take in-effect on first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : _ResetParameterCallback
        The callback that resets the parameter after the first iteration.
    """
    return _ResetParameterCallback(**kwargs)
wxchan's avatar
wxchan committed
271
272


273
274
275
276
277
278
279
280
class _EarlyStoppingCallback:
    """Internal early stopping callable class.

    For every entry of ``env.evaluation_result_list`` the callback keeps the best
    score seen so far and the iteration it was seen at. It raises
    ``EarlyStopException`` when a monitored entry has not improved for
    ``stopping_rounds`` iterations, or when the last iteration is reached.
    """

    def __init__(
        self,
        stopping_rounds: int,
        first_metric_only: bool = False,
        verbose: bool = True,
        min_delta: Union[float, List[float]] = 0.0,
    ) -> None:
        if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
            raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")

        # NOTE(review): ``order`` and ``before_iteration`` are not read in this class;
        # presumably the training loop uses them to schedule callbacks - confirm there.
        self.order = 30
        self.before_iteration = False

        self.stopping_rounds = stopping_rounds
        self.first_metric_only = first_metric_only
        self.verbose = verbose
        self.min_delta = min_delta

        self.enabled = True
        self._reset_storages()

    def _reset_storages(self) -> None:
        """Drop all per-entry tracking state."""
        # All lists below are indexed like ``env.evaluation_result_list``.
        self.best_score: List[float] = []
        self.best_iter: List[int] = []
        self.best_score_list: List[_ListOfEvalResultTuples] = []
        self.cmp_op: List[Callable[[float, float], bool]] = []
        self.first_metric = ""

    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        """Return True if ``curr_score`` beats ``best_score`` by more than ``delta`` (higher is better)."""
        return curr_score > best_score + delta

    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        """Return True if ``curr_score`` beats ``best_score`` by more than ``delta`` (lower is better)."""
        return curr_score < best_score - delta

    def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
        """Check, by name, if a given Dataset is the training data."""
        # for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
        # and those metrics are considered for early stopping
        if ds_name == "cv_agg" and eval_name == "train":
            return True

        # for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
        if isinstance(env.model, Booster) and ds_name == env.model._train_data_name:
            return True

        return False

    def _init(self, env: CallbackEnv) -> None:
        """Validate the setup and prepare the tracking state for a new run.

        Raises
        ------
        ValueError
            If there are no evaluation results, or if ``min_delta`` is negative
            or is a list whose length matches neither 0, 1 nor the number of metrics.
        """
        # A callback object may be reused for several runs. Re-enable it here so that a
        # previous run that disabled it (dart mode, train-only evaluation) does not
        # silently switch early stopping off for this one.
        self.enabled = True

        if env.evaluation_result_list is None or env.evaluation_result_list == []:
            raise ValueError("For early stopping, at least one dataset and eval metric is required for evaluation")

        is_dart = any(env.params.get(alias, "") == "dart" for alias in _ConfigAliases.get("boosting"))
        if is_dart:
            self.enabled = False
            _log_warning("Early stopping is not available in dart mode")
            return

        # validation sets are guaranteed to not be identical to the training data in cv()
        if isinstance(env.model, Booster):
            only_train_set = len(env.evaluation_result_list) == 1 and self._is_train_set(
                ds_name=env.evaluation_result_list[0][0],
                eval_name=env.evaluation_result_list[0][1].split(" ")[0],
                env=env,
            )
            if only_train_set:
                self.enabled = False
                _log_warning("Only training set found, disabling early stopping.")
                return

        if self.verbose:
            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

        self._reset_storages()

        n_metrics = len({m[1] for m in env.evaluation_result_list})
        n_datasets = len(env.evaluation_result_list) // n_metrics
        # Build one delta per evaluation entry.
        if isinstance(self.min_delta, list):
            if not all(t >= 0 for t in self.min_delta):
                raise ValueError("Values for early stopping min_delta must be non-negative.")
            if len(self.min_delta) == 0:
                if self.verbose:
                    _log_info("Disabling min_delta for early stopping.")
                deltas = [0.0] * n_datasets * n_metrics
            elif len(self.min_delta) == 1:
                if self.verbose:
                    _log_info(f"Using {self.min_delta[0]} as min_delta for all metrics.")
                deltas = self.min_delta * n_datasets * n_metrics
            else:
                if len(self.min_delta) != n_metrics:
                    raise ValueError("Must provide a single value for min_delta or as many as metrics.")
                if self.first_metric_only and self.verbose:
                    _log_info(f"Using only {self.min_delta[0]} as early stopping min_delta.")
                # One delta per metric, repeated for every dataset.
                deltas = self.min_delta * n_datasets
        else:
            if self.min_delta < 0:
                raise ValueError("Early stopping min_delta must be non-negative.")
            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
                _log_info(f"Using {self.min_delta} as min_delta for all metrics.")
            deltas = [self.min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            self.best_iter.append(0)
            if eval_ret[3]:  # greater is better
                self.best_score.append(float("-inf"))
                self.cmp_op.append(partial(self._gt_delta, delta=delta))
            else:
                self.best_score.append(float("inf"))
                self.cmp_op.append(partial(self._lt_delta, delta=delta))

    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        """Raise ``EarlyStopException`` with the best state of entry ``i`` on the last iteration."""
        if env.iteration == env.end_iteration - 1:
            if self.verbose:
                best_score_str = "\t".join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                _log_info(
                    "Did not meet early stopping. " f"Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}"
                )
                if self.first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

    def __call__(self, env: CallbackEnv) -> None:
        """Update the best scores with this iteration and stop training if warranted.

        Raises
        ------
        EarlyStopException
            When a monitored entry has gone ``stopping_rounds`` iterations without
            improving, or when the last iteration is reached.
        RuntimeError
            If the callback is enabled but ``env`` carries no evaluation results.
        """
        if env.iteration == env.begin_iteration:
            self._init(env)
        if not self.enabled:
            return
        if env.evaluation_result_list is None:
            raise RuntimeError(
                "early_stopping() callback enabled but no evaluation results found. This is a probably bug in LightGBM. "
                "Please report it at https://github.com/microsoft/LightGBM/issues"
            )
        # self.best_score_list is initialized to an empty list
        first_time_updating_best_score_list = self.best_score_list == []
        for i, eval_ret in enumerate(env.evaluation_result_list):
            score = eval_ret[2]
            if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]):
                self.best_score[i] = score
                self.best_iter[i] = env.iteration
                if first_time_updating_best_score_list:
                    self.best_score_list.append(env.evaluation_result_list)
                else:
                    self.best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = eval_ret[1].split(" ")
            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if self._is_train_set(
                ds_name=eval_ret[0],
                eval_name=eval_name_splitted[0],
                env=env,
            ):
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                if self.verbose:
                    eval_result_str = "\t".join(
                        [_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]]
                    )
                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                    if self.first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
            self._final_iteration_check(env, eval_name_splitted, i)


441
442
443
444
445
446
def early_stopping(
    stopping_rounds: int,
    first_metric_only: bool = False,
    verbose: bool = True,
    min_delta: Union[float, List[float]] = 0.0,
) -> _EarlyStoppingCallback:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them. But the training data is ignored anyway.
    To check only the first metric set ``first_metric_only`` to True.
    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.

    Parameters
    ----------
    stopping_rounds : int
        The possible number of rounds without the trend occurrence.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

        .. versionadded:: 4.0.0

    Returns
    -------
    callback : _EarlyStoppingCallback
        The callback that activates early stopping.
    """
    return _EarlyStoppingCallback(
        stopping_rounds=stopping_rounds,
        first_metric_only=first_metric_only,
        verbose=verbose,
        min_delta=min_delta,
    )