# coding: utf-8
"""Callbacks library."""
import collections
from operator import gt, lt

from .basic import _ConfigAliases, _log_info, _log_warning


class EarlyStopException(Exception):
    """Exception of early stopping."""

    def __init__(self, best_iteration, best_score):
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
        best_score : float
            The score of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "CallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])


def _format_eval_result(value, show_stdv=True):
    """Format metric string."""
    if len(value) == 4:
        return '%s\'s %s: %g' % (value[0], value[1], value[2])
    elif len(value) == 5:
        if show_stdv:
            return '%s\'s %s: %g + %g' % (value[0], value[1], value[2], value[4])
        else:
            return '%s\'s %s: %g' % (value[0], value[1], value[2])
    else:
        raise ValueError("Wrong metric value")


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to print the evaluation results.
    show_stdv : bool, optional (default=True)
        Whether to show stdv (if provided).

    Returns
    -------
    callback : function
        The callback that prints the evaluation results every ``period`` iteration(s).
    """
    def _callback(env):
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            _log_info('[%d]\t%s' % (env.iteration + 1, result))
    _callback.order = 10
    return _callback
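
# Usage sketch (illustrative; ``params``, ``train_set`` and ``valid_set`` are
# placeholders): print metrics every 10 iterations during training.
#
#     import lightgbm as lgb
#     bst = lgb.train(params, train_set, valid_sets=[valid_set],
#                     callbacks=[print_evaluation(period=10)])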


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        A dictionary to store the evaluation results.

    Returns
    -------
    callback : function
        The callback that records the evaluation history into the passed dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')
    eval_result.clear()

    def _init(env):
        for data_name, eval_name, _, _ in env.evaluation_result_list:
            eval_result.setdefault(data_name, collections.OrderedDict())
            eval_result[data_name].setdefault(eval_name, [])

    def _callback(env):
        if not eval_result:
            _init(env)
        for data_name, eval_name, result, _ in env.evaluation_result_list:
            eval_result[data_name][eval_name].append(result)
    _callback.order = 20
    return _callback
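
# Usage sketch (illustrative): the history is recorded per dataset and then
# per metric, so after training eval_result['valid_0']['l2'] would be a list
# with one value per iteration (dataset/metric names depend on your setup).
#
#     eval_result = {}
#     bst = lgb.train(params, train_set, valid_sets=[valid_set],
#                     callbacks=[record_evaluation(eval_result)])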


def reset_parameter(**kwargs):
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or function
        List of parameters for each boosting round
        or a customized function that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If function func, parameter = func(current_round).

    Returns
    -------
    callback : function
        The callback that resets the parameter after the first iteration.
    """
    def _callback(env):
        new_parameters = {}
        for key, value in kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError("Length of list {} has to equal to 'num_boost_round'."
                                     .format(repr(key)))
                new_param = value[env.iteration - env.begin_iteration]
            else:
                new_param = value(env.iteration - env.begin_iteration)
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    _callback.before_iteration = True
    _callback.order = 10
    return _callback
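
# Usage sketch (illustrative): decay ``learning_rate`` as a function of the
# current round, or pass a list with one value per boosting round instead.
#
#     bst = lgb.train(params, train_set, num_boost_round=100,
#                     callbacks=[reset_parameter(
#                         learning_rate=lambda round_num: 0.1 * (0.99 ** round_num))])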


def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score stops improving.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation dataset and one metric.
    If there's more than one, all of them will be checked, but the training data is ignored.
    To check only the first metric, set ``first_metric_only`` to True.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without any improvement in the validation score
        after which training will be stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to print message with early stopping information.

    Returns
    -------
    callback : function
        The callback that activates early stopping.
    """
    best_score = []
    best_iter = []
    best_score_list = []
    cmp_op = []
    enabled = [True]
    first_metric = ['']

    def _init(env):
        enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                             in _ConfigAliases.get("boosting"))
        if not enabled[0]:
            _log_warning('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if verbose:
            _log_info("Training until validation scores don't improve for {} rounds".format(stopping_rounds))

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret in env.evaluation_result_list:
            best_iter.append(0)
            best_score_list.append(None)
            if eval_ret[3]:
                best_score.append(float('-inf'))
                cmp_op.append(gt)
            else:
                best_score.append(float('inf'))
                cmp_op.append(lt)

    def _final_iteration_check(env, eval_name_splitted, i):
        if env.iteration == env.end_iteration - 1:
            if verbose:
                _log_info('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
                    best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                if first_metric_only:
                    _log_info("Evaluated only: {}".format(eval_name_splitted[-1]))
            raise EarlyStopException(best_iter[i], best_score_list[i])

    def _callback(env):
        if not cmp_op:
            _init(env)
        if not enabled[0]:
            return
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                _final_iteration_check(env, eval_name_splitted, i)
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    _log_info('Early stopping, best iteration is:\n[%d]\t%s' % (
                        best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                    if first_metric_only:
                        _log_info("Evaluated only: {}".format(eval_name_splitted[-1]))
                raise EarlyStopException(best_iter[i], best_score_list[i])
            _final_iteration_check(env, eval_name_splitted, i)
    _callback.order = 30
    return _callback
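
# Usage sketch (illustrative): stop training once the validation score has not
# improved for 50 consecutive rounds; the best round is then exposed by the
# Booster as ``bst.best_iteration``.
#
#     bst = lgb.train(params, train_set, num_boost_round=1000,
#                     valid_sets=[valid_set],
#                     callbacks=[early_stopping(stopping_rounds=50)])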