"tests/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "cdfe97f5d799013352b405bb31183f28eff1d1ca"
callback.py 19.5 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Callbacks library."""
3
4
from collections import OrderedDict
from dataclasses import dataclass
5
from functools import partial
6
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
wxchan's avatar
wxchan committed
7

8
9
from .basic import (Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType,
                    _LGBM_BoosterEvalMethodResultWithStandardDeviationType, _log_info, _log_warning)
10
11
12

if TYPE_CHECKING:
    from .engine import CVBooster
wxchan's avatar
wxchan committed
13

14
# Public callback API of this module.
__all__ = [
    'EarlyStopException',
    'early_stopping',
    'log_evaluation',
    'record_evaluation',
    'reset_parameter',
]

22
# Structure filled in by ``record_evaluation()``:
# {data_name: {eval_name: [value_at_round_0, value_at_round_1, ...]}}
_EvalResultDict = Dict[str, Dict[str, List[Any]]]
# One evaluation entry: the 4-element tuple used for regular training, or the
# 5-element tuple (with a trailing standard deviation) used for cv().
_EvalResultTuple = Union[
    _LGBM_BoosterEvalMethodResultType,
    _LGBM_BoosterEvalMethodResultWithStandardDeviationType
]
# All evaluation entries of a single boosting round.
_ListOfEvalResultTuples = Union[
    List[_LGBM_BoosterEvalMethodResultType],
    List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType]
]

wxchan's avatar
wxchan committed
32

wxchan's avatar
wxchan committed
33
class EarlyStopException(Exception):
    """Exception of early stopping.

    Raise this from a callback passed in via keyword argument ``callbacks``
    in ``cv()`` or ``train()`` to trigger early stopping.
    """

    def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
            0-based, so pass ``best_iteration=2`` to indicate that the third iteration was the best one.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score
wxchan's avatar
wxchan committed
54

wxchan's avatar
wxchan committed
55

wxchan's avatar
wxchan committed
56
# Callback environment used by callbacks
@dataclass
class CallbackEnv:
    """Snapshot of the training state handed to each callback.

    The callbacks in this module compare ``iteration`` with ``begin_iteration``
    and ``end_iteration`` to detect the first and the last boosting round, and
    read metric values from ``evaluation_result_list``.
    """

    # Model being trained: a ``Booster``, or a ``CVBooster`` holding one Booster per fold.
    model: Union[Booster, "CVBooster"]
    # Training parameters; ``reset_parameter()`` updates this dict in place.
    params: Dict[str, Any]
    # Index of the current boosting round.
    iteration: int
    # Index of the first round of this training run.
    begin_iteration: int
    # One past the last round index; the final round is ``end_iteration - 1``.
    end_iteration: int
    # Evaluation entries for the current round; may be ``None``.
    evaluation_result_list: Optional[_ListOfEvalResultTuples]
wxchan's avatar
wxchan committed
65

wxchan's avatar
wxchan committed
66

67
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
    """Format metric string.

    Parameters
    ----------
    value : tuple
        Either ``(eval_name, metric_name, eval_result, is_higher_better)`` or
        ``(eval_name, metric_name, eval_result, is_higher_better, stdv)``.
    show_stdv : bool
        Whether to append ``" + <stdv>"``; only affects 5-element tuples.

    Returns
    -------
    str
        ``"<eval_name>'s <metric_name>: <eval_result>"``, optionally followed by ``" + <stdv>"``.

    Raises
    ------
    ValueError
        If ``value`` has neither 4 nor 5 elements.
    """
    if len(value) == 4:
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    elif len(value) == 5:
        if show_stdv:
            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"  # type: ignore[misc]
        else:
            return f"{value[0]}'s {value[1]}: {value[2]:g}"
    else:
        raise ValueError("Wrong metric value")
wxchan's avatar
wxchan committed
78
79


80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class _LogEvaluationCallback:
    """Internal log evaluation callable class."""

    def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
        # Scheduling metadata; presumably read by the training loop to sort
        # callbacks and to run this one after (not before) each iteration.
        self.order = 10
        self.before_iteration = False

        self.period = period
        self.show_stdv = show_stdv

    def __call__(self, env: CallbackEnv) -> None:
        """Log this round's evaluation results when the round number is a multiple of ``period``."""
        # A non-positive period disables logging; nothing to log without results.
        if self.period <= 0 or not env.evaluation_result_list:
            return
        round_number = env.iteration + 1  # 1-based round shown to the user
        if round_number % self.period != 0:
            return
        formatted = [_format_eval_result(entry, self.show_stdv) for entry in env.evaluation_result_list]
        _log_info(f'[{round_number}]\t' + '\t'.join(formatted))


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
    """Create a callback that logs the evaluation results.

    By default, standard output resource is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation data.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
        If ``period`` is not positive, this callback logs nothing.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : _LogEvaluationCallback
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    return _LogEvaluationCallback(period=period, show_stdv=show_stdv)
wxchan's avatar
wxchan committed
120
121


122
123
124
class _RecordEvaluationCallback:
    """Internal record evaluation callable class."""

    def __init__(self, eval_result: _EvalResultDict) -> None:
        # Scheduling metadata; presumably read by the training loop to sort
        # callbacks and to run this one after each iteration.
        self.order = 20
        self.before_iteration = False

        if not isinstance(eval_result, dict):
            raise TypeError('eval_result should be a dictionary')
        # Caller-owned dict that is filled in place.
        self.eval_result = eval_result

    @staticmethod
    def _get_evaluation_results(env: CallbackEnv) -> _ListOfEvalResultTuples:
        """Return ``env.evaluation_result_list``, raising RuntimeError if it is missing."""
        if env.evaluation_result_list is None:
            raise RuntimeError(
                "record_evaluation() callback enabled but no evaluation results found. This is a probably bug in LightGBM. "
                "Please report it at https://github.com/microsoft/LightGBM/issues"
            )
        return env.evaluation_result_list

    def _init(self, env: CallbackEnv) -> None:
        """Clear ``eval_result`` and create one empty history list per recorded metric."""
        evaluation_results = self._get_evaluation_results(env)
        self.eval_result.clear()
        for item in evaluation_results:
            if len(item) == 4:  # regular train
                data_name, eval_name = item[:2]
                self.eval_result.setdefault(data_name, OrderedDict())
                self.eval_result[data_name].setdefault(eval_name, [])
            else:  # cv: item[1] is "<dataset type> <metric>"
                # maxsplit=1 keeps metric names that contain spaces in one piece
                data_name, eval_name = item[1].split(maxsplit=1)
                self.eval_result.setdefault(data_name, OrderedDict())
                self.eval_result[data_name].setdefault(f'{eval_name}-mean', [])
                self.eval_result[data_name].setdefault(f'{eval_name}-stdv', [])

    def __call__(self, env: CallbackEnv) -> None:
        """Append this round's results to the history lists (initialising them on the first round)."""
        if env.iteration == env.begin_iteration:
            self._init(env)
        evaluation_results = self._get_evaluation_results(env)
        for item in evaluation_results:
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                self.eval_result[data_name][eval_name].append(result)
            else:
                data_name, eval_name = item[1].split(maxsplit=1)
                res_mean = item[2]
                res_stdv = item[4]  # type: ignore[misc]
                self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
                self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)


172
def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss'
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

        When used with ``cv()``, each metric is stored as two lists,
        ``'<metric>-mean'`` and ``'<metric>-stdv'``.

    Returns
    -------
    callback : _RecordEvaluationCallback
        The callback that records the evaluation history into the passed dictionary.
    """
    return _RecordEvaluationCallback(eval_result=eval_result)
wxchan's avatar
wxchan committed
206
207


208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
class _ResetParameterCallback:
    """Internal reset parameter callable class."""

    def __init__(self, **kwargs: Union[list, Callable]) -> None:
        # Scheduling metadata; ``before_iteration=True`` presumably makes the
        # training loop run this callback before each boosting round.
        self.order = 10
        self.before_iteration = True

        # parameter name -> list of per-round values, or callable(round_index) -> value
        self.kwargs = kwargs

    def __call__(self, env: CallbackEnv) -> None:
        """Compute this round's parameter values and push any changed ones to the model(s).

        Raises
        ------
        ValueError
            If a list value's length differs from the number of rounds, or a
            value is neither a list nor a callable.
        """
        # Round index relative to the start of this training run.
        round_index = env.iteration - env.begin_iteration
        n_rounds = env.end_iteration - env.begin_iteration
        new_parameters: Dict[str, Any] = {}
        for key, value in self.kwargs.items():
            if isinstance(value, list):
                if len(value) != n_rounds:
                    raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.")
                new_param = value[round_index]
            elif callable(value):
                new_param = value(round_index)
            else:
                raise ValueError("Only list and callable values are supported "
                                 "as a mapping from boosting round index to new parameter value.")
            # Only forward parameters whose value actually changes.
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            if isinstance(env.model, Booster):
                env.model.reset_parameter(new_parameters)
            else:
                # CVBooster holds a list of Booster objects, each needs to be updated
                for booster in env.model.boosters:
                    booster.reset_parameter(new_parameters)
            env.params.update(new_parameters)


241
def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).
        ``current_round`` counts from 0 at the start of the training run,
        and a list must contain exactly ``num_boost_round`` values
        (otherwise the callback raises ``ValueError`` when it runs).

    Returns
    -------
    callback : _ResetParameterCallback
        The callback that resets the parameter after the first iteration.
    """
    return _ResetParameterCallback(**kwargs)
wxchan's avatar
wxchan committed
263
264


265
266
267
268
269
270
271
272
273
274
class _EarlyStoppingCallback:
    """Internal early stopping callable class."""

    def __init__(
        self,
        stopping_rounds: int,
        first_metric_only: bool = False,
        verbose: bool = True,
        min_delta: Union[float, List[float]] = 0.0
    ) -> None:
        """Store settings; see ``early_stopping()`` for the meaning of each parameter.

        Raises
        ------
        ValueError
            If ``stopping_rounds`` is not a positive integer.
        """
        if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
            raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")

        # Scheduling metadata; presumably read by the training loop to sort
        # callbacks and to run this one after each iteration.
        self.order = 30
        self.before_iteration = False

        self.stopping_rounds = stopping_rounds
        self.first_metric_only = first_metric_only
        self.verbose = verbose
        self.min_delta = min_delta

        self.enabled = True
        self._reset_storages()

    def _reset_storages(self) -> None:
        """Clear per-run tracking state (one entry per element of ``env.evaluation_result_list``)."""
        self.best_score: List[float] = []
        self.best_iter: List[int] = []
        self.best_score_list: List[_ListOfEvalResultTuples] = []
        self.cmp_op: List[Callable[[float, float], bool]] = []
        self.first_metric = ''

    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score > best_score + delta

    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score < best_score - delta

    def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
        """Check, by name, if a given Dataset is the training data."""
        # for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
        # and those metrics are considered for early stopping
        if ds_name == "cv_agg" and eval_name == "train":
            return True

        # for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
        if isinstance(env.model, Booster) and ds_name == env.model._train_data_name:
            return True

        return False

    def _metric_name(self, eval_ret: _EvalResultTuple, env: CallbackEnv) -> str:
        """Return the metric name of an evaluation entry, without any dataset prefix.

        cv() reports names as "<dataset type> <metric>" (e.g. "valid l1"), while
        train() entries already carry the bare metric name.
        """
        if isinstance(env.model, Booster):
            return eval_ret[1]
        return eval_ret[1].split(" ", 1)[-1]

    def _init(self, env: CallbackEnv) -> None:
        """Validate the evaluation setup and initialise tracking state on the first round.

        Raises
        ------
        ValueError
            If there are no evaluation results, or ``min_delta`` is negative or has
            a length that matches neither 1 nor the number of metrics.
        """
        if env.evaluation_result_list is None or env.evaluation_result_list == []:
            raise ValueError(
                "For early stopping, at least one dataset and eval metric is required for evaluation"
            )

        # The same callback object may be reused for another training run (state is
        # reset below for that reason); a previous run that disabled early stopping
        # (dart, or training set only) must not leave it disabled for this run.
        self.enabled = True

        is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
        if is_dart:
            self.enabled = False
            _log_warning('Early stopping is not available in dart mode')
            return

        # validation sets are guaranteed to not be identical to the training data in cv()
        if isinstance(env.model, Booster):
            only_train_set = (
                len(env.evaluation_result_list) == 1
                and self._is_train_set(
                    ds_name=env.evaluation_result_list[0][0],
                    eval_name=env.evaluation_result_list[0][1].split(" ")[0],
                    env=env
                )
            )
            if only_train_set:
                self.enabled = False
                _log_warning('Only training set found, disabling early stopping.')
                return

        if self.verbose:
            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

        self._reset_storages()

        # Count metrics by bare name: in cv() with eval_train_metric=True, "train l1" and
        # "valid l1" are the same metric on two datasets, not two metrics.
        n_metrics = len({self._metric_name(m, env) for m in env.evaluation_result_list})
        n_datasets = len(env.evaluation_result_list) // n_metrics
        # Build one delta per evaluation entry; entries are assumed to be ordered
        # dataset-major (all metrics of dataset 0, then dataset 1, ...).
        if isinstance(self.min_delta, list):
            if not all(t >= 0 for t in self.min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(self.min_delta) == 0:
                if self.verbose:
                    _log_info('Disabling min_delta for early stopping.')
                deltas = [0.0] * n_datasets * n_metrics
            elif len(self.min_delta) == 1:
                if self.verbose:
                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
                deltas = self.min_delta * n_datasets * n_metrics
            else:
                if len(self.min_delta) != n_metrics:
                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
                if self.first_metric_only and self.verbose:
                    _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
                deltas = self.min_delta * n_datasets
        else:
            if self.min_delta < 0:
                raise ValueError('Early stopping min_delta must be non-negative.')
            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
                _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
            deltas = [self.min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            self.best_iter.append(0)
            if eval_ret[3]:  # greater is better
                self.best_score.append(float('-inf'))
                self.cmp_op.append(partial(self._gt_delta, delta=delta))
            else:
                self.best_score.append(float('inf'))
                self.cmp_op.append(partial(self._lt_delta, delta=delta))

    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        """On the last round, log the best result of entry ``i`` (if verbose) and raise EarlyStopException."""
        if env.iteration == env.end_iteration - 1:
            if self.verbose:
                best_score_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
                if self.first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

    def __call__(self, env: CallbackEnv) -> None:
        """Update best scores and raise EarlyStopException once a tracked entry stops improving."""
        if env.iteration == env.begin_iteration:
            self._init(env)
        if not self.enabled:
            return
        if env.evaluation_result_list is None:
            raise RuntimeError(
                "early_stopping() callback enabled but no evaluation results found. This is a probably bug in LightGBM. "
                "Please report it at https://github.com/microsoft/LightGBM/issues"
            )
        # self.best_score_list is initialized to an empty list
        first_time_updating_best_score_list = (self.best_score_list == [])
        for i, eval_ret in enumerate(env.evaluation_result_list):
            score = eval_ret[2]
            if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]):
                self.best_score[i] = score
                self.best_iter[i] = env.iteration
                if first_time_updating_best_score_list:
                    self.best_score_list.append(env.evaluation_result_list)
                else:
                    self.best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = eval_ret[1].split(" ")
            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if self._is_train_set(
                ds_name=eval_ret[0],
                eval_name=eval_name_splitted[0],
                env=env
            ):
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                if self.verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                    if self.first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
            self._final_iteration_check(env, eval_name_splitted, i)


def early_stopping(
    stopping_rounds: int,
    first_metric_only: bool = False,
    verbose: bool = True,
    min_delta: Union[float, List[float]] = 0.0
) -> _EarlyStoppingCallback:
    """Create a callback that activates early stopping.

    Training continues only while the validation score keeps improving by at
    least ``min_delta``; it must improve at least once every ``stopping_rounds``
    round(s), otherwise training stops.
    At least one validation dataset and one metric are required.
    When several are present, all of them are checked, but the training data is always ignored.
    Set ``first_metric_only`` to True to check only the first metric.
    The index of the best-performing iteration is saved in the ``best_iteration`` attribute of the model.

    Parameters
    ----------
    stopping_rounds : int
        The possible number of rounds without the trend occurrence.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

        .. versionadded:: 4.0.0

    Returns
    -------
    callback : _EarlyStoppingCallback
        The callback that activates early stopping.
    """
    callback = _EarlyStoppingCallback(
        stopping_rounds=stopping_rounds,
        first_metric_only=first_metric_only,
        verbose=verbose,
        min_delta=min_delta,
    )
    return callback