# coding: utf-8
# pylint: disable = invalid-name, W0105, C0301
"""Callbacks library."""
from __future__ import absolute_import

import collections
import warnings
from operator import gt, lt

from .basic import _ConfigAliases
from .compat import range_


class EarlyStopException(Exception):
    """Exception of early stopping."""

    def __init__(self, best_iteration, best_score):
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration at which training stopped.
        best_score : float
            The score of the best iteration.
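
        Examples
        --------
        A sketch of how a custom callback can use this exception to stop
        training (``env`` is the ``CallbackEnv`` passed to each callback;
        the threshold of 10 iterations is arbitrary for illustration):

        >>> def stop_after_ten(env):
        ...     if env.iteration >= 9:
        ...         raise EarlyStopException(env.iteration, env.evaluation_result_list)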
        """
        super(EarlyStopException, self).__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
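# Field notes (as the fields are used in this module): ``model`` is the
# Booster being trained (a CV wrapper during ``lgb.cv``), ``params`` is the
# current parameter dict, ``iteration`` is the 0-based index of the current
# round, ``begin_iteration``/``end_iteration`` delimit the round range, and
# ``evaluation_result_list`` holds (dataset_name, metric_name, value,
# is_higher_better[, stdv]) tuples.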
CallbackEnv = collections.namedtuple(
    "LightGBMCallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])


def _format_eval_result(value, show_stdv=True):
    """Format metric string."""
    if len(value) == 4:
        return '%s\'s %s: %g' % (value[0], value[1], value[2])
    elif len(value) == 5:
        if show_stdv:
            return '%s\'s %s: %g + %g' % (value[0], value[1], value[2], value[4])
        else:
            return '%s\'s %s: %g' % (value[0], value[1], value[2])
    else:
        raise ValueError("Wrong metric value")


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to print the evaluation results.
    show_stdv : bool, optional (default=True)
        Whether to show stdv (if provided).

    Returns
    -------
    callback : function
        The callback that prints the evaluation results every ``period`` iteration(s).
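
    Examples
    --------
    A typical usage sketch; ``params``, ``train_set`` and ``valid_set`` are
    placeholders for objects prepared beforehand:

    >>> import lightgbm as lgb
    >>> bst = lgb.train(params, train_set, valid_sets=[valid_set],
    ...                 callbacks=[lgb.print_evaluation(period=10)])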
    """
    def _callback(env):
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            print('[%d]\t%s' % (env.iteration + 1, result))
    _callback.order = 10
    return _callback


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        A dictionary to store the evaluation results.

    Returns
    -------
    callback : function
        The callback that records the evaluation history into the passed dictionary.
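
    Examples
    --------
    A minimal usage sketch; ``params``, ``train_set`` and ``valid_set`` are
    placeholders for objects prepared beforehand:

    >>> import lightgbm as lgb
    >>> history = {}
    >>> bst = lgb.train(params, train_set, valid_sets=[valid_set],
    ...                 valid_names=['eval'],
    ...                 callbacks=[lgb.record_evaluation(history)])
    >>> # history now maps e.g. 'eval' -> {'l2': [one value per iteration]}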
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')
    eval_result.clear()

    def _init(env):
        for data_name, eval_name, _, _ in env.evaluation_result_list:
            eval_result.setdefault(data_name, collections.OrderedDict())
            eval_result[data_name].setdefault(eval_name, [])

    def _callback(env):
        if not eval_result:
            _init(env)
        for data_name, eval_name, result, _ in env.evaluation_result_list:
            eval_result[data_name][eval_name].append(result)
    _callback.order = 20
    return _callback


def reset_parameter(**kwargs):
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or function
        List of parameters for each boosting round
        or a customized function that calculates the parameter from the
        current round number (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If function func, parameter = func(current_round).

    Returns
    -------
    callback : function
        The callback that resets the parameter after the first iteration.
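
    Examples
    --------
    A sketch of learning rate decay; the ``0.1 * 0.99 ** current_round``
    schedule is only an illustration:

    >>> import lightgbm as lgb
    >>> bst = lgb.train(params, train_set,
    ...                 callbacks=[lgb.reset_parameter(
    ...                     learning_rate=lambda current_round: 0.1 * 0.99 ** current_round)])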
    """
    def _callback(env):
        new_parameters = {}
        for key, value in kwargs.items():
            if key in _ConfigAliases.get("num_class", "boosting", "metric"):
                raise RuntimeError("Cannot reset {} during training".format(repr(key)))
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError("Length of list {} has to equal 'num_boost_round'."
                                     .format(repr(key)))
                new_param = value[env.iteration - env.begin_iteration]
            else:
                new_param = value(env.iteration - env.begin_iteration)
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    _callback.before_iteration = True
    _callback.order = 10
    return _callback


def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score stops improving.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation set and one metric.
    If there's more than one, will check all of them, but the training data is ignored.
    To check only the first metric, set ``first_metric_only`` to True.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without improvement after which training will be stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to print a message with early stopping information.

    Returns
    -------
    callback : function
        The callback that activates early stopping.
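
    Examples
    --------
    A minimal usage sketch; ``params``, ``train_set`` and ``valid_set`` are
    placeholders for objects prepared beforehand:

    >>> import lightgbm as lgb
    >>> bst = lgb.train(params, train_set, num_boost_round=1000,
    ...                 valid_sets=[valid_set],
    ...                 callbacks=[lgb.early_stopping(stopping_rounds=50)])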
    """
    best_score = []
    best_iter = []
    best_score_list = []
    cmp_op = []
    enabled = [True]
    first_metric = ['']

    def _init(env):
        enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                             in _ConfigAliases.get("boosting"))
        if not enabled[0]:
            warnings.warn('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if verbose:
            msg = "Training until validation scores don't improve for {} rounds"
            print(msg.format(stopping_rounds))

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret in env.evaluation_result_list:
            best_iter.append(0)
            best_score_list.append(None)
            if eval_ret[3]:
                best_score.append(float('-inf'))
                cmp_op.append(gt)
            else:
                best_score.append(float('inf'))
                cmp_op.append(lt)

    def _final_iteration_check(env, eval_name_splitted, i):
        if env.iteration == env.end_iteration - 1:
            if verbose:
                print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
                    best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                if first_metric_only:
                    print("Evaluated only: {}".format(eval_name_splitted[-1]))
            raise EarlyStopException(best_iter[i], best_score_list[i])

    def _callback(env):
        if not cmp_op:
            _init(env)
        if not enabled[0]:
            return
        for i in range_(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                _final_iteration_check(env, eval_name_splitted, i)
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    print('Early stopping, best iteration is:\n[%d]\t%s' % (
                        best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                    if first_metric_only:
                        print("Evaluated only: {}".format(eval_name_splitted[-1]))
                raise EarlyStopException(best_iter[i], best_score_list[i])
            _final_iteration_check(env, eval_name_splitted, i)
    _callback.order = 30
    return _callback