# coding: utf-8
"""Callbacks library."""

from collections import OrderedDict
from dataclasses import dataclass
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union

from .basic import (
    Booster,
    _ConfigAliases,
    _LGBM_BoosterEvalMethodResultType,
    _LGBM_BoosterEvalMethodResultWithStandardDeviationType,
    _log_info,
    _log_warning,
)

if TYPE_CHECKING:
    from .engine import CVBooster

__all__ = [
    "EarlyStopException",
    "early_stopping",
    "log_evaluation",
    "record_evaluation",
    "reset_parameter",
]

_EvalResultDict = Dict[str, Dict[str, List[Any]]]
_EvalResultTuple = Union[
    _LGBM_BoosterEvalMethodResultType,
    _LGBM_BoosterEvalMethodResultWithStandardDeviationType,
]
_ListOfEvalResultTuples = Union[
    List[_LGBM_BoosterEvalMethodResultType],
    List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType],
]


class EarlyStopException(Exception):
    """Exception of early stopping.

    Raise this from a callback passed in via keyword argument ``callbacks``
    in ``cv()`` or ``train()`` to trigger early stopping.
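
    .. rubric:: Example

    A minimal sketch of a custom callback that raises this exception; the
    round threshold and callback name below are illustrative placeholders,
    and the function would be passed via ``callbacks`` in ``train()`` or ``cv()``:

    .. code-block:: python

        import lightgbm as lgb

        def stop_after_20_rounds(env):
            # 'env' is the CallbackEnv passed to every callback;
            # reuse the current scores as the best scores when stopping
            if env.iteration - env.begin_iteration >= 20 and env.evaluation_result_list:
                raise lgb.EarlyStopException(
                    best_iteration=env.iteration,
                    best_score=env.evaluation_result_list,
                )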
    """

    def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
            0-based... pass ``best_iteration=2`` to indicate that the third iteration was the best one.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
@dataclass
class CallbackEnv:
    model: Union[Booster, "CVBooster"]
    params: Dict[str, Any]
    iteration: int
    begin_iteration: int
    end_iteration: int
    evaluation_result_list: Optional[_ListOfEvalResultTuples]


def _is_using_cv(env: CallbackEnv) -> bool:
    """Check if model in callback env is a CVBooster."""
    # this import is here to avoid a circular import
    from .engine import CVBooster  # noqa: PLC0415

    return isinstance(env.model, CVBooster)


def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
    """Format metric string."""
    dataset_name, metric_name, metric_value, *_ = value
    out = f"{dataset_name}'s {metric_name}: {metric_value:g}"
    # tuples from cv() sometimes have a 5th item, with standard deviation of
    # the evaluation metric (taken over all cross-validation folds)
    if show_stdv and len(value) == 5:
        out += f" + {value[4]:g}"
    return out


class _LogEvaluationCallback:
    """Internal log evaluation callable class."""

    def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
        self.order = 10
        self.before_iteration = False

        self.period = period
        self.show_stdv = show_stdv

    def __call__(self, env: CallbackEnv) -> None:
        if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0:
            result = "\t".join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list])
            _log_info(f"[{env.iteration + 1}]\t{result}")


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
    """Create a callback that logs the evaluation results.

    By default, standard output resource is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation data.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : _LogEvaluationCallback
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
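
    .. rubric:: Example

    A minimal usage sketch; the random data, parameter values, and logging
    period below are illustrative placeholders only:

    .. code-block:: python

        import lightgbm as lgb
        import numpy as np

        X, y = np.random.rand(100, 5), np.random.rand(100)
        train_data = lgb.Dataset(X[:80], label=y[:80])
        valid_data = lgb.Dataset(X[80:], label=y[80:], reference=train_data)

        # log the validation metrics every 10 boosting rounds
        booster = lgb.train(
            {"objective": "regression", "verbose": -1},
            train_data,
            num_boost_round=50,
            valid_sets=[valid_data],
            callbacks=[lgb.log_evaluation(period=10)],
        )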
    """
    return _LogEvaluationCallback(period=period, show_stdv=show_stdv)


class _RecordEvaluationCallback:
    """Internal record evaluation callable class."""

    def __init__(self, eval_result: _EvalResultDict) -> None:
        self.order = 20
        self.before_iteration = False

        if not isinstance(eval_result, dict):
            raise TypeError("eval_result should be a dictionary")
        self.eval_result = eval_result

    def _init(self, env: CallbackEnv) -> None:
        if env.evaluation_result_list is None:
            raise RuntimeError(
                "record_evaluation() callback enabled but no evaluation results found. This is probably a bug in LightGBM. "
                "Please report it at https://github.com/microsoft/LightGBM/issues"
            )
        self.eval_result.clear()
        for item in env.evaluation_result_list:
            dataset_name, metric_name, *_ = item
            self.eval_result.setdefault(dataset_name, OrderedDict())
            if len(item) == 4:
                self.eval_result[dataset_name].setdefault(metric_name, [])
            else:
                self.eval_result[dataset_name].setdefault(f"{metric_name}-mean", [])
                self.eval_result[dataset_name].setdefault(f"{metric_name}-stdv", [])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        if env.evaluation_result_list is None:
            raise RuntimeError(
                "record_evaluation() callback enabled but no evaluation results found. This is probably a bug in LightGBM. "
                "Please report it at https://github.com/microsoft/LightGBM/issues"
            )
        for item in env.evaluation_result_list:
            # for cv(), 'metric_value' is actually a mean of metric values over all CV folds
            dataset_name, metric_name, metric_value, *_ = item
            if len(item) == 4:
                # train()
                self.eval_result[dataset_name][metric_name].append(metric_value)
            else:
                # cv()
                metric_std_dev = item[4]  # type: ignore[misc]
                self.eval_result[dataset_name][f"{metric_name}-mean"].append(metric_value)
                self.eval_result[dataset_name][f"{metric_name}-stdv"].append(metric_std_dev)


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss',
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : _RecordEvaluationCallback
        The callback that records the evaluation history into the passed dictionary.
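
    .. rubric:: Example

    A minimal usage sketch; the random data, dataset names, and parameter
    values below are illustrative placeholders only:

    .. code-block:: python

        import lightgbm as lgb
        import numpy as np

        X, y = np.random.rand(100, 5), np.random.rand(100)
        train_data = lgb.Dataset(X[:80], label=y[:80])
        valid_data = lgb.Dataset(X[80:], label=y[80:], reference=train_data)

        eval_result = {}  # must start out empty, as noted above
        lgb.train(
            {"objective": "regression", "verbose": -1},
            train_data,
            num_boost_round=20,
            valid_sets=[train_data, valid_data],
            valid_names=["train", "eval"],
            callbacks=[lgb.record_evaluation(eval_result)],
        )
        # eval_result now maps dataset name -> metric name -> list of values,
        # e.g. eval_result["eval"]["l2"]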
    """
    return _RecordEvaluationCallback(eval_result=eval_result)


class _ResetParameterCallback:
    """Internal reset parameter callable class."""

    def __init__(self, **kwargs: Union[list, Callable]) -> None:
        self.order = 10
        self.before_iteration = True

        self.kwargs = kwargs

    def __call__(self, env: CallbackEnv) -> None:
        new_parameters = {}
        for key, value in self.kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.")
                new_param = value[env.iteration - env.begin_iteration]
            elif callable(value):
                new_param = value(env.iteration - env.begin_iteration)
            else:
                raise ValueError(
                    "Only list and callable values are supported "
                    "as a mapping from boosting round index to new parameter value."
                )
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            if isinstance(env.model, Booster):
                env.model.reset_parameter(new_parameters)
            else:
                # CVBooster holds a list of Booster objects, each needs to be updated
                for booster in env.model.boosters:
                    booster.reset_parameter(new_parameters)
            env.params.update(new_parameters)


def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter in terms of
        the current number of rounds (e.g. to yield learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : _ResetParameterCallback
        The callback that resets the parameter after the first iteration.
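
    .. rubric:: Example

    A minimal usage sketch applying learning-rate decay; the random data and
    the decay schedule below are illustrative placeholders only:

    .. code-block:: python

        import lightgbm as lgb
        import numpy as np

        X, y = np.random.rand(100, 5), np.random.rand(100)
        train_data = lgb.Dataset(X, label=y)

        # start at 0.1 and shrink the learning rate by 1% every round
        booster = lgb.train(
            {"objective": "regression", "learning_rate": 0.1, "verbose": -1},
            train_data,
            num_boost_round=30,
            callbacks=[lgb.reset_parameter(learning_rate=lambda i: 0.1 * (0.99**i))],
        )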
    """
    return _ResetParameterCallback(**kwargs)


class _EarlyStoppingCallback:
    """Internal early stopping callable class."""

    def __init__(
        self,
        stopping_rounds: int,
        first_metric_only: bool = False,
        verbose: bool = True,
        min_delta: Union[float, List[float]] = 0.0,
    ) -> None:
        self.enabled = _should_enable_early_stopping(stopping_rounds)

        self.order = 30
        self.before_iteration = False

        self.stopping_rounds = stopping_rounds
        self.first_metric_only = first_metric_only
        self.verbose = verbose
        self.min_delta = min_delta

        self._reset_storages()

    def _reset_storages(self) -> None:
        self.best_score: List[float] = []
        self.best_iter: List[int] = []
        self.best_score_list: List[_ListOfEvalResultTuples] = []
        self.cmp_op: List[Callable[[float, float], bool]] = []
        self.first_metric = ""

    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score > best_score + delta

    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score < best_score - delta

    def _is_train_set(self, dataset_name: str, env: CallbackEnv) -> bool:
        """Check, by name, if a given Dataset is the training data."""
        # for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
        # and those metrics are considered for early stopping
        if _is_using_cv(env) and dataset_name == "train":
            return True

        # for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
        if isinstance(env.model, Booster) and dataset_name == env.model._train_data_name:
            return True

        return False

    def _init(self, env: CallbackEnv) -> None:
        if env.evaluation_result_list is None or env.evaluation_result_list == []:
            raise ValueError("For early stopping, at least one dataset and eval metric is required for evaluation")

        is_dart = any(env.params.get(alias, "") == "dart" for alias in _ConfigAliases.get("boosting"))
        if is_dart:
            self.enabled = False
            _log_warning("Early stopping is not available in dart mode")
            return

        # get details of the first dataset
        first_dataset_name, first_metric_name, *_ = env.evaluation_result_list[0]

        # validation sets are guaranteed to not be identical to the training data in cv()
        if isinstance(env.model, Booster):
            only_train_set = len(env.evaluation_result_list) == 1 and self._is_train_set(
                dataset_name=first_dataset_name,
                env=env,
            )
            if only_train_set:
                self.enabled = False
                _log_warning("Only training set found, disabling early stopping.")
                return

        if self.verbose:
            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

        self._reset_storages()

        n_metrics = len({m[1] for m in env.evaluation_result_list})
        n_datasets = len(env.evaluation_result_list) // n_metrics
        if isinstance(self.min_delta, list):
            if not all(t >= 0 for t in self.min_delta):
                raise ValueError("Values for early stopping min_delta must be non-negative.")
            if len(self.min_delta) == 0:
                if self.verbose:
                    _log_info("Disabling min_delta for early stopping.")
                deltas = [0.0] * n_datasets * n_metrics
            elif len(self.min_delta) == 1:
                if self.verbose:
                    _log_info(f"Using {self.min_delta[0]} as min_delta for all metrics.")
                deltas = self.min_delta * n_datasets * n_metrics
            else:
                if len(self.min_delta) != n_metrics:
                    raise ValueError("Must provide a single value for min_delta or as many as metrics.")
                if self.first_metric_only and self.verbose:
                    _log_info(f"Using only {self.min_delta[0]} as early stopping min_delta.")
                deltas = self.min_delta * n_datasets
        else:
            if self.min_delta < 0:
                raise ValueError("Early stopping min_delta must be non-negative.")
            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
                _log_info(f"Using {self.min_delta} as min_delta for all metrics.")
            deltas = [self.min_delta] * n_datasets * n_metrics

        self.first_metric = first_metric_name
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            self.best_iter.append(0)
            if eval_ret[3]:  # greater is better
                self.best_score.append(float("-inf"))
                self.cmp_op.append(partial(self._gt_delta, delta=delta))
            else:
                self.best_score.append(float("inf"))
                self.cmp_op.append(partial(self._lt_delta, delta=delta))

    def _final_iteration_check(self, *, env: CallbackEnv, metric_name: str, i: int) -> None:
        if env.iteration == env.end_iteration - 1:
            if self.verbose:
                best_score_str = "\t".join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                _log_info(
                    f"Did not meet early stopping. Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}"
                )
                if self.first_metric_only:
                    _log_info(f"Evaluated only: {metric_name}")
            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        if not self.enabled:
            return
        if env.evaluation_result_list is None:
            raise RuntimeError(
                "early_stopping() callback enabled but no evaluation results found. This is probably a bug in LightGBM. "
                "Please report it at https://github.com/microsoft/LightGBM/issues"
            )
        # self.best_score_list is initialized to an empty list
        first_time_updating_best_score_list = self.best_score_list == []
        for i in range(len(env.evaluation_result_list)):
            dataset_name, metric_name, metric_value, *_ = env.evaluation_result_list[i]
            if first_time_updating_best_score_list or self.cmp_op[i](metric_value, self.best_score[i]):
                self.best_score[i] = metric_value
                self.best_iter[i] = env.iteration
                if first_time_updating_best_score_list:
                    self.best_score_list.append(env.evaluation_result_list)
                else:
                    self.best_score_list[i] = env.evaluation_result_list
            if self.first_metric_only and self.first_metric != metric_name:
                continue  # use only the first metric for early stopping
            if self._is_train_set(
                dataset_name=dataset_name,
                env=env,
            ):
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                if self.verbose:
                    eval_result_str = "\t".join(
                        [_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]]
                    )
                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                    if self.first_metric_only:
                        _log_info(f"Evaluated only: {metric_name}")
                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
            self._final_iteration_check(env=env, metric_name=metric_name, i=i)


def _should_enable_early_stopping(stopping_rounds: Any) -> bool:
    """Check if early stopping should be activated.

    This function will evaluate to True if the early stopping callback should be
    activated (i.e. stopping_rounds > 0).  It also provides an informative error if the
    type is not int.
    """
    if not isinstance(stopping_rounds, int):
        raise TypeError(f"early_stopping_round should be an integer. Got '{type(stopping_rounds).__name__}'")
    return stopping_rounds > 0


def early_stopping(
    stopping_rounds: int,
    first_metric_only: bool = False,
    verbose: bool = True,
    min_delta: Union[float, List[float]] = 0.0,
) -> _EarlyStoppingCallback:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them. But the training data is ignored anyway.
    To check only the first metric, set ``first_metric_only`` to True.
    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.

    .. note::

        If using ``boosting_type="dart"``, this callback has no effect and early stopping
        will not be performed.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without improvement after which training will be stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

        .. versionadded:: 4.0.0

    Returns
    -------
    callback : _EarlyStoppingCallback
        The callback that activates early stopping.
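
    .. rubric:: Example

    A minimal usage sketch; the random data, metric, and thresholds below are
    illustrative placeholders only:

    .. code-block:: python

        import lightgbm as lgb
        import numpy as np

        X, y = np.random.rand(200, 5), np.random.rand(200)
        train_data = lgb.Dataset(X[:150], label=y[:150])
        valid_data = lgb.Dataset(X[150:], label=y[150:], reference=train_data)

        # stop if the validation score has not improved by at least 1e-4
        # for 5 consecutive rounds
        booster = lgb.train(
            {"objective": "regression", "verbose": -1},
            train_data,
            num_boost_round=1000,
            valid_sets=[valid_data],
            callbacks=[lgb.early_stopping(stopping_rounds=5, min_delta=1e-4)],
        )
        print(booster.best_iteration)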
    """
    return _EarlyStoppingCallback(
        stopping_rounds=stopping_rounds,
        first_metric_only=first_metric_only,
        verbose=verbose,
        min_delta=min_delta,
    )