"tests/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "7d582dd8b8528c39169b1cc3150fce0149637109"
callback.py 14.2 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
"""Callbacks library."""
import collections
from functools import partial
from typing import Any, Callable, Dict, List, Tuple, Union

from .basic import _ConfigAliases, _log_info, _log_warning

_EvalResultTuple = Union[
    List[Tuple[str, str, float, bool]],
    List[Tuple[str, str, float, bool, float]]
]


def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
    return curr_score > best_score + delta


def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
    return curr_score < best_score - delta


class EarlyStopException(Exception):
    """Exception of early stopping."""

    def __init__(self, best_iteration: int, best_score: _EvalResultTuple) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])


def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
    """Format metric string."""
    if len(value) == 4:
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    elif len(value) == 5:
        if show_stdv:
            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"
        else:
            return f"{value[0]}'s {value[1]}: {value[2]:g}"
    else:
        raise ValueError("Wrong metric value")


def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
    """Create a callback that logs the evaluation results.

    By default, standard output is used.
    Use the ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation dataset.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using the ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : callable
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
    """
    def _callback(env: CallbackEnv) -> None:
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            _log_info(f'[{env.iteration + 1}]\t{result}')
    _callback.order = 10  # type: ignore
    return _callback
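# A minimal usage sketch (assumes ``lightgbm`` is importable as ``lgb`` and that
# ``params``, ``train_set`` and ``valid_set`` are defined by the caller):
#
#     import lightgbm as lgb
#     bst = lgb.train(
#         params,
#         train_set,
#         valid_sets=[valid_set],
#         callbacks=[lgb.log_evaluation(period=10, show_stdv=False)],
#     )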


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train' and one evaluation metric named 'logloss',
        this dictionary will have the following structure after training finishes:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : callable
        The callback that records the evaluation history into the passed dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')

    def _init(env: CallbackEnv) -> None:
        eval_result.clear()
        for item in env.evaluation_result_list:
            if len(item) == 4:  # regular train
                data_name, eval_name = item[:2]
            else:  # cv
                data_name, eval_name = item[1].split()
            eval_result.setdefault(data_name, collections.OrderedDict())
            if len(item) == 4:
                eval_result[data_name].setdefault(eval_name, [])
            else:
                eval_result[data_name].setdefault(f'{eval_name}-mean', [])
                eval_result[data_name].setdefault(f'{eval_name}-stdv', [])

    def _callback(env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            _init(env)
        for item in env.evaluation_result_list:
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                eval_result[data_name][eval_name].append(result)
            else:
                data_name, eval_name = item[1].split()
                res_mean, res_stdv = item[2], item[4]
                eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
                eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)
    _callback.order = 20  # type: ignore
    return _callback
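# A minimal usage sketch (``params``, ``train_set`` and ``valid_set`` are assumed
# to be defined by the caller):
#
#     evals_result: dict = {}
#     bst = lgb.train(
#         params,
#         train_set,
#         valid_sets=[valid_set],
#         valid_names=['eval'],
#         callbacks=[lgb.record_evaluation(evals_result)],
#     )
#     # evals_result now maps 'eval' -> {metric_name: [per-iteration values]}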


def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter from the current round number
        (e.g. to implement learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : callable
        The callback that resets the parameter after the first iteration.
    """
    def _callback(env: CallbackEnv) -> None:
        new_parameters = {}
        for key, value in kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.")
                new_param = value[env.iteration - env.begin_iteration]
            else:
                new_param = value(env.iteration - env.begin_iteration)
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    _callback.before_iteration = True  # type: ignore
    _callback.order = 10  # type: ignore
    return _callback
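# A minimal usage sketch: decay the learning rate each round with a callable
# (the 0.99 decay factor is illustrative):
#
#     callbacks = [lgb.reset_parameter(learning_rate=lambda it: 0.1 * (0.99 ** it))]
#     bst = lgb.train(params, train_set, num_boost_round=100, callbacks=callbacks)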


def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation dataset and one metric.
    If there's more than one, all of them will be checked, but the training data is ignored.
    To check only the first metric, set ``first_metric_only`` to True.
    The index of the iteration with the best performance will be saved in the ``best_iteration`` attribute of the model.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without any improvement after which training stops.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log messages with early stopping information.
        By default, standard output is used.
        Use the ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

    Returns
    -------
    callback : callable
        The callback that activates early stopping.
    """
    best_score = []
    best_iter = []
    best_score_list: list = []
    cmp_op = []
    enabled = True
    first_metric = ''

    def _init(env: CallbackEnv) -> None:
        nonlocal best_score
        nonlocal best_iter
        nonlocal best_score_list
        nonlocal cmp_op
        nonlocal enabled
        nonlocal first_metric
        enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                          in _ConfigAliases.get("boosting"))
        if not enabled:
            _log_warning('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if stopping_rounds <= 0:
            raise ValueError("stopping_rounds should be greater than zero.")

        if verbose:
            _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")

        # reset storages
        best_score = []
        best_iter = []
        best_score_list = []
        cmp_op = []
        first_metric = ''

        # Normalize ``min_delta`` into one threshold per (dataset, metric) pair,
        # ordered like ``env.evaluation_result_list``.
        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
        n_datasets = len(env.evaluation_result_list) // n_metrics
        if isinstance(min_delta, list):
            if not all(t >= 0 for t in min_delta):
                raise ValueError('Values for early stopping min_delta must be non-negative.')
            if len(min_delta) == 0:
                if verbose:
                    _log_info('Disabling min_delta for early stopping.')
                deltas = [0.0] * n_datasets * n_metrics
            elif len(min_delta) == 1:
                if verbose:
                    _log_info(f'Using {min_delta[0]} as min_delta for all metrics.')
                deltas = min_delta * n_datasets * n_metrics
            else:
                if len(min_delta) != n_metrics:
                    raise ValueError('Must provide a single value for min_delta or one value per metric.')
                if first_metric_only and verbose:
                    _log_info(f'Using only {min_delta[0]} as early stopping min_delta.')
                deltas = min_delta * n_datasets
        else:
            if min_delta < 0:
                raise ValueError('Early stopping min_delta must be non-negative.')
            if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose:
                _log_info(f'Using {min_delta} as min_delta for all metrics.')
            deltas = [min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            best_iter.append(0)
            best_score_list.append(None)
            if eval_ret[3]:  # greater is better
                best_score.append(float('-inf'))
                cmp_op.append(partial(_gt_delta, delta=delta))
            else:
                best_score.append(float('inf'))
                cmp_op.append(partial(_lt_delta, delta=delta))

    def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        nonlocal best_iter
        nonlocal best_score_list
        if env.iteration == env.end_iteration - 1:
            if verbose:
                best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                _log_info('Did not meet early stopping. '
                          f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
                if first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(best_iter[i], best_score_list[i])

    def _callback(env: CallbackEnv) -> None:
        nonlocal best_score
        nonlocal best_iter
        nonlocal best_score_list
        nonlocal cmp_op
        nonlocal enabled
        nonlocal first_metric
        if env.iteration == env.begin_iteration:
            _init(env)
        if not enabled:
            return
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if first_metric_only and first_metric != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                _final_iteration_check(env, eval_name_splitted, i)
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
                    _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
                    if first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(best_iter[i], best_score_list[i])
            _final_iteration_check(env, eval_name_splitted, i)
    _callback.order = 30  # type: ignore
    return _callback
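# A minimal usage sketch (``stopping_rounds`` and ``min_delta`` values are
# illustrative; requires at least one validation set):
#
#     bst = lgb.train(
#         params,
#         train_set,
#         valid_sets=[valid_set],
#         callbacks=[lgb.early_stopping(stopping_rounds=50, min_delta=1e-4)],
#     )
#     print(bst.best_iteration)  # 1-based index of the best iteration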