"vscode:/vscode.git/clone" did not exist on "3ec345e2a2849ea25b5d7627d53afe050557cfbc"
callback.py 17.4 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Callbacks library."""
3
4
from collections import OrderedDict
from dataclasses import dataclass
5
from functools import partial
6
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
wxchan's avatar
wxchan committed
7

8
9
from .basic import (Booster, _ConfigAliases, _LGBM_BoosterEvalMethodResultType,
                    _LGBM_BoosterEvalMethodResultWithStandardDeviationType, _log_info, _log_warning)
10
11
12

if TYPE_CHECKING:
    from .engine import CVBooster
wxchan's avatar
wxchan committed
13

14
15
16
17
18
19
20
# Public API of this module: the factory functions that build the callbacks.
__all__ = [
    'early_stopping',
    'log_evaluation',
    'record_evaluation',
    'reset_parameter',
]

# Evaluation history: dataset name -> metric name -> list of per-iteration values.
_EvalResultDict = Dict[str, Dict[str, List[Any]]]
# A single evaluation entry: either the 4-element form or the 5-element form
# whose last element is a standard deviation.
_EvalResultTuple = Union[
    _LGBM_BoosterEvalMethodResultType,
    _LGBM_BoosterEvalMethodResultWithStandardDeviationType
]
# A list of evaluation entries, all of the same form (the type of ``CallbackEnv.evaluation_result_list``).
_ListOfEvalResultTuples = Union[
    List[_LGBM_BoosterEvalMethodResultType],
    List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType]
]

wxchan's avatar
wxchan committed
31

wxchan's avatar
wxchan committed
32
class EarlyStopException(Exception):
    """Exception of early stopping.

    Raised by the early stopping callback in this module; it carries the best
    iteration found and the evaluation results recorded at that iteration.
    """

    def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score
wxchan's avatar
wxchan committed
48

wxchan's avatar
wxchan committed
49

wxchan's avatar
wxchan committed
50
# Callback environment: the state object passed to every callback on each call.
@dataclass
class CallbackEnv:
    """State object passed to every callback each time it is called."""

    # Model being trained.
    model: Union[Booster, "CVBooster"]
    # Training parameters; the reset-parameter callback updates this mapping in place.
    params: Dict[str, Any]
    # Index of the current boosting iteration.
    iteration: int
    # Iteration index at which callbacks (re)initialize their state.
    begin_iteration: int
    # Callbacks in this module treat ``end_iteration - 1`` as the final iteration.
    end_iteration: int
    # Evaluation entries for the current iteration; may be None.
    evaluation_result_list: Optional[_ListOfEvalResultTuples]
wxchan's avatar
wxchan committed
59

wxchan's avatar
wxchan committed
60

61
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
    """Format metric string.

    Parameters
    ----------
    value : tuple
        Evaluation entry with 4 elements, or with 5 elements where the last one is the standard deviation.
        Elements 0 and 1 are used as names and element 2 as the metric value.
    show_stdv : bool
        Whether to append the standard deviation when ``value`` has 5 elements.

    Returns
    -------
    result : str
        Text of the form ``"<name>'s <metric>: <value>"``, followed by ``" + <stdv>"`` when requested and available.

    Raises
    ------
    ValueError
        If ``value`` has neither 4 nor 5 elements.
    """
    n_fields = len(value)
    if n_fields not in (4, 5):
        raise ValueError("Wrong metric value")
    text = f"{value[0]}'s {value[1]}: {value[2]:g}"
    # the standard deviation is only present in the 5-element form
    if n_fields == 5 and show_stdv:
        text += f" + {value[4]:g}"  # type: ignore[misc]
    return text
wxchan's avatar
wxchan committed
72
73


74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class _LogEvaluationCallback:
    """Internal log evaluation callable class.

    Every ``period`` iterations, writes one info-level log line containing all
    evaluation entries of the current iteration, separated by tabs.
    """

    def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
        self.period = period
        self.show_stdv = show_stdv

        self.before_iteration = False
        self.order = 10

    def __call__(self, env: CallbackEnv) -> None:
        # nothing to do when logging is switched off or there is nothing to report
        if not self.period > 0 or not env.evaluation_result_list:
            return
        # rounds are reported 1-based
        round_number = env.iteration + 1
        if round_number % self.period != 0:
            return
        formatted = [_format_eval_result(entry, self.show_stdv) for entry in env.evaluation_result_list]
        joined = '\t'.join(formatted)
        _log_info(f'[{round_number}]\t{joined}')


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
    """Create a callback that logs the evaluation results.

    By default, standard output resource is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation data.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
        A value that is not greater than zero disables logging by this callback.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : _LogEvaluationCallback
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    return _LogEvaluationCallback(period=period, show_stdv=show_stdv)
wxchan's avatar
wxchan committed
114
115


116
117
118
class _RecordEvaluationCallback:
    """Internal record evaluation callable class.

    Appends the evaluation results of every iteration to the dictionary supplied by the caller.
    """

    def __init__(self, eval_result: _EvalResultDict) -> None:
        self.order = 20
        self.before_iteration = False

        if not isinstance(eval_result, dict):
            raise TypeError('eval_result should be a dictionary')
        # kept by reference: results are written into the caller's dictionary
        self.eval_result = eval_result

    def _init(self, env: CallbackEnv) -> None:
        """Empty the target dictionary and create one empty series per dataset/metric pair."""
        self.eval_result.clear()
        for item in env.evaluation_result_list:
            if len(item) == 4:  # regular train
                data_name, eval_name = item[:2]
            else:  # cv
                # here the second element holds "<dataset> <metric>" in a single string
                data_name, eval_name = item[1].split()
            self.eval_result.setdefault(data_name, OrderedDict())
            if len(item) == 4:
                self.eval_result[data_name].setdefault(eval_name, [])
            else:
                # entries with a standard deviation are stored as two separate series
                self.eval_result[data_name].setdefault(f'{eval_name}-mean', [])
                self.eval_result[data_name].setdefault(f'{eval_name}-stdv', [])

    def __call__(self, env: CallbackEnv) -> None:
        # (re)build the storage on the first iteration of a run
        if env.iteration == env.begin_iteration:
            self._init(env)
        for item in env.evaluation_result_list:
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                self.eval_result[data_name][eval_name].append(result)
            else:
                data_name, eval_name = item[1].split()
                res_mean = item[2]
                res_stdv = item[4]
                self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
                self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)


156
def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss'
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : _RecordEvaluationCallback
        The callback that records the evaluation history into the passed dictionary.

    Raises
    ------
    TypeError
        If ``eval_result`` is not a dictionary.
    """
    return _RecordEvaluationCallback(eval_result=eval_result)
wxchan's avatar
wxchan committed
190
191


192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
class _ResetParameterCallback:
    """Internal reset parameter callable class.

    Holds one schedule (a list or a callable) per parameter name and, when called,
    applies the values that differ from the ones currently stored in ``env.params``.
    """

    def __init__(self, **kwargs: Union[list, Callable]) -> None:
        self.kwargs = kwargs

        self.before_iteration = True
        self.order = 10

    def __call__(self, env: CallbackEnv) -> None:
        # position of the current round relative to the start of this run
        round_index = env.iteration - env.begin_iteration
        changed = {}
        for name, schedule in self.kwargs.items():
            if isinstance(schedule, list):
                # a list schedule must provide exactly one value per round
                if len(schedule) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {name!r} has to be equal to 'num_boost_round'.")
                candidate = schedule[round_index]
            elif callable(schedule):
                candidate = schedule(round_index)
            else:
                raise ValueError("Only list and callable values are supported "
                                 "as a mapping from boosting round index to new parameter value.")
            # only forward values that actually differ from the current ones
            if candidate != env.params.get(name, None):
                changed[name] = candidate
        if changed:
            env.model.reset_parameter(changed)
            env.params.update(changed)


220
def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take in-effect on first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : _ResetParameterCallback
        The callback that resets the parameter after the first iteration.
    """
    return _ResetParameterCallback(**kwargs)
wxchan's avatar
wxchan committed
242
243


244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
class _EarlyStoppingCallback:
    """Internal early stopping callable class.

    For every evaluation entry it tracks the best score seen so far and the iteration at which it was seen.
    ``EarlyStopException`` is raised once a considered entry has not improved for ``stopping_rounds`` iterations,
    or when the final iteration is reached.
    """

    def __init__(
        self,
        stopping_rounds: int,
        first_metric_only: bool = False,
        verbose: bool = True,
        min_delta: Union[float, List[float]] = 0.0
    ) -> None:
        self.order = 30
        self.before_iteration = False

        self.stopping_rounds = stopping_rounds
        self.first_metric_only = first_metric_only
        self.verbose = verbose
        self.min_delta = min_delta

        self.enabled = True
        self._reset_storages()

    def _reset_storages(self) -> None:
        """Clear all per-entry tracking state."""
        self.best_score: List[float] = []
        self.best_iter: List[int] = []
        self.best_score_list: List[_ListOfEvalResultTuples] = []
        self.cmp_op: List[Callable[[float, float], bool]] = []
        self.first_metric = ''

    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        """Return True if ``curr_score`` exceeds ``best_score`` by more than ``delta``."""
        return curr_score > best_score + delta

    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        """Return True if ``curr_score`` is below ``best_score`` by more than ``delta``."""
        return curr_score < best_score - delta

    def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
        """Check, by name, whether an evaluation entry refers to the training data."""
        return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name

    def _init(self, env: CallbackEnv) -> None:
        """Validate the configuration and set up per-entry tracking state.

        Raises
        ------
        ValueError
            If there is nothing to evaluate, if ``stopping_rounds`` is not positive,
            or if ``min_delta`` is negative or has an unsupported length.
        """
        # evaluation_result_list is Optional: treat None like an empty list so that the
        # explicit "at least one dataset" error below is raised instead of a TypeError from len(None)
        eval_results = env.evaluation_result_list or []

        is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
        only_train_set = (
            len(eval_results) == 1
            and self._is_train_set(
                ds_name=eval_results[0][0],
                eval_name=eval_results[0][1].split(" ")[0],
                train_name=env.model._train_data_name)
        )
        self.enabled = not is_dart and not only_train_set
        if not self.enabled:
            if is_dart:
                _log_warning('Early stopping is not available in dart mode')
            elif only_train_set:
                _log_warning('Only training set found, disabling early stopping.')
            return
        if not eval_results:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if self.stopping_rounds <= 0:
            raise ValueError("stopping_rounds should be greater than zero.")

        if self.verbose:
            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

        self._reset_storages()

        n_metrics = len({m[1] for m in eval_results})
        n_datasets = len(eval_results) // n_metrics
        # expand min_delta into one threshold per evaluation entry
        if isinstance(self.min_delta, list):
            if not all(t >= 0 for t in self.min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(self.min_delta) == 0:
                if self.verbose:
                    _log_info('Disabling min_delta for early stopping.')
                deltas = [0.0] * n_datasets * n_metrics
            elif len(self.min_delta) == 1:
                if self.verbose:
                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
                deltas = self.min_delta * n_datasets * n_metrics
            else:
                if len(self.min_delta) != n_metrics:
                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
                if self.first_metric_only and self.verbose:
                    _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
                deltas = self.min_delta * n_datasets
        else:
            if self.min_delta < 0:
                raise ValueError('Early stopping min_delta must be non-negative.')
            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
                _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
            deltas = [self.min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        self.first_metric = eval_results[0][1].split(" ")[-1]
        for eval_ret, delta in zip(eval_results, deltas):
            self.best_iter.append(0)
            if eval_ret[3]:  # greater is better
                self.best_score.append(float('-inf'))
                self.cmp_op.append(partial(self._gt_delta, delta=delta))
            else:
                self.best_score.append(float('inf'))
                self.cmp_op.append(partial(self._lt_delta, delta=delta))

    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        """On the final iteration, optionally log the best result for entry ``i`` and raise ``EarlyStopException``."""
        if env.iteration == env.end_iteration - 1:
            if self.verbose:
                best_score_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
                if self.first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

    def __call__(self, env: CallbackEnv) -> None:
        # (re)initialize on the first iteration of a run
        if env.iteration == env.begin_iteration:
            self._init(env)
        if not self.enabled:
            return
        # self.best_score_list is initialized to an empty list
        first_time_updating_best_score_list = (self.best_score_list == [])
        for i, eval_ret in enumerate(env.evaluation_result_list):
            score = eval_ret[2]
            if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]):
                self.best_score[i] = score
                self.best_iter[i] = env.iteration
                # keep the full result list of the iteration that was best for entry i
                if first_time_updating_best_score_list:
                    self.best_score_list.append(env.evaluation_result_list)
                else:
                    self.best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = eval_ret[1].split(" ")
            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if self._is_train_set(eval_ret[0], eval_name_splitted[0], env.model._train_data_name):
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                if self.verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                    if self.first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
            self._final_iteration_check(env, eval_name_splitted, i)


def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> _EarlyStoppingCallback:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them. But the training data is ignored anyway.
    To check only the first metric set ``first_metric_only`` to True.
    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.

    Parameters
    ----------
    stopping_rounds : int
        The possible number of rounds without the trend occurrence.
        Should be greater than zero.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output resource is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics;
        a one-element list is applied to all metrics and an empty list disables ``min_delta``.
        Values must be non-negative.

        .. versionadded:: 4.0.0

    Returns
    -------
    callback : _EarlyStoppingCallback
        The callback that activates early stopping.
    """
    callback = _EarlyStoppingCallback(
        stopping_rounds=stopping_rounds,
        first_metric_only=first_metric_only,
        verbose=verbose,
        min_delta=min_delta,
    )
    return callback