callback.py 9.45 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
wxchan's avatar
wxchan committed
2
# pylint: disable = invalid-name, W0105, C0301
3
"""Callbacks library."""
wxchan's avatar
wxchan committed
4
from __future__ import absolute_import
5

wxchan's avatar
wxchan committed
6
import collections
7
import warnings
wxchan's avatar
wxchan committed
8
from operator import gt, lt
wxchan's avatar
wxchan committed
9

wxchan's avatar
wxchan committed
10
11
from .compat import range_

wxchan's avatar
wxchan committed
12

wxchan's avatar
wxchan committed
13
class EarlyStopException(Exception):
    """Signal raised by the early-stopping callback to terminate training."""

    def __init__(self, best_iteration, best_score):
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The (0-based) iteration at which the monitored metric was best.
        best_score : list
            The evaluation results recorded at ``best_iteration``. The
            early-stopping callback in this module passes that iteration's
            complete ``evaluation_result_list`` (a list of result tuples).
        """
        super(EarlyStopException, self).__init__()
        # Both values are read back by whoever catches the exception.
        self.best_score = best_score
        self.best_iteration = best_iteration
wxchan's avatar
wxchan committed
29

wxchan's avatar
wxchan committed
30

wxchan's avatar
wxchan committed
31
32
33
34
# Callback environment handed to every callback:
#   model                  -- the model being trained (``reset_parameter`` calls
#                             its ``reset_parameter`` method)
#   params                 -- dict of current training parameters
#   iteration              -- index of the current iteration
#   begin_iteration        -- first iteration of this training run
#   end_iteration          -- exclusive end; the last iteration is end_iteration - 1
#   evaluation_result_list -- list of (data_name, eval_name, score, is_higher_better)
#                             tuples, optionally with a trailing stdv element
CallbackEnv = collections.namedtuple(
    "LightGBMCallbackEnv",
    ("model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"),
)

wxchan's avatar
wxchan committed
41

wxchan's avatar
wxchan committed
42
def _format_eval_result(value, show_stdv=True):
    """Format metric string.

    ``value`` is either a 4-tuple ``(data_name, eval_name, score, is_higher_better)``
    or a 5-tuple carrying an extra standard deviation as its last element.
    The standard deviation is appended as ``" + <stdv>"`` only when
    ``show_stdv`` is true. Any other length raises ``ValueError``.
    """
    size = len(value)
    if size == 5 and show_stdv:
        return '%s\'s %s: %g + %g' % (value[0], value[1], value[2], value[4])
    if size in (4, 5):
        return '%s\'s %s: %g' % (value[0], value[1], value[2])
    raise ValueError("Wrong metric value")
wxchan's avatar
wxchan committed
53
54
55


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to print the evaluation results.
        Nothing is printed when ``period`` is not positive.
    show_stdv : bool, optional (default=True)
        Whether to show stdv (if provided).

    Returns
    -------
    callback : function
        The callback that prints the evaluation results every ``period`` iteration(s).
    """
    def _callback(env):
        if period <= 0:
            return
        entries = env.evaluation_result_list
        if not entries:
            return
        # Rounds are reported 1-based.
        round_number = env.iteration + 1
        if round_number % period != 0:
            return
        formatted = [_format_eval_result(entry, show_stdv) for entry in entries]
        print('[%d]\t%s' % (round_number, '\t'.join(formatted)))
    _callback.order = 10
    return _callback
wxchan's avatar
wxchan committed
76
77
78


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
       A dictionary to store the evaluation results.
       It is cleared when the callback is created. It is then filled as
       ``eval_result[data_name][eval_name]`` -> list of scores, with one score
       appended per callback call. Entries with 5 elements (aggregated
       results that also carry a standard deviation) record only their score
       (third element); the standard deviation is not stored.

    Returns
    -------
    callback : function
        The callback that records the evaluation history into the passed dictionary.

    Raises
    ------
    TypeError
        If ``eval_result`` is not a dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')
    eval_result.clear()

    def _init(env):
        # Index rather than unpack: entries may have 4 elements or 5 (with a
        # trailing stdv), and unpacking into exactly 4 names fails on the latter.
        for item in env.evaluation_result_list:
            data_name, eval_name = item[0], item[1]
            eval_result.setdefault(data_name, collections.OrderedDict())
            eval_result[data_name].setdefault(eval_name, [])

    def _callback(env):
        if not eval_result:
            _init(env)
        for item in env.evaluation_result_list:
            eval_result[item[0]][item[1]].append(item[2])
    _callback.order = 20
    return _callback
wxchan's avatar
wxchan committed
107
108


109
def reset_parameter(**kwargs):
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or function
        List of parameters for each boosting round
        or a customized function that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If function func, parameter = func(current_round).
        ``current_round`` counts from the first iteration of the current run
        (``env.iteration - env.begin_iteration``). A list must hold exactly one
        value per boosting round.

    Returns
    -------
    callback : function
        The callback that resets the parameter after the first iteration.
        It runs before each iteration. Only parameters whose value actually
        changes are passed to ``env.model.reset_parameter``, and they are
        mirrored into ``env.params``. It raises ``RuntimeError`` for parameters
        that cannot be changed during training. It raises ``ValueError`` for a
        list whose length does not match the number of rounds.
    """
    def _callback(env):
        changed = {}
        for name, schedule in kwargs.items():
            if name in ('num_class', 'num_classes',
                        'boosting', 'boost', 'boosting_type',
                        'metric', 'metrics', 'metric_types'):
                raise RuntimeError("Cannot reset {} during training".format(repr(name)))
            offset = env.iteration - env.begin_iteration
            if not isinstance(schedule, list):
                candidate = schedule(offset)
            elif len(schedule) == env.end_iteration - env.begin_iteration:
                candidate = schedule[offset]
            else:
                raise ValueError("Length of list {} has to equal to 'num_boost_round'."
                                 .format(repr(name)))
            # Skip values that would not change anything.
            if candidate != env.params.get(name, None):
                changed[name] = candidate
        if changed:
            env.model.reset_parameter(changed)
            env.params.update(changed)
    _callback.before_iteration = True
    _callback.order = 10
    return _callback
wxchan's avatar
wxchan committed
152
153


154
def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
    """Create a callback that activates early stopping.

    The model will train until the validation score stops improving.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them. But the training data is ignored
    anyway, including on the final round. As a result, the reported best iteration
    always comes from validation data.
    To check only the first metric set ``first_metric_only`` to True.

    .. note::

        Early stopping is disabled (with a warning) when the boosting type is ``dart``.

    Parameters
    ----------
    stopping_rounds : int
       Training stops once this many rounds have passed since the best score of
       a monitored validation metric.
    first_metric_only : bool, optional (default=False)
       Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to print message with early stopping information.

    Returns
    -------
    callback : function
        The callback that activates early stopping.
        It raises ``EarlyStopException`` when stopping, or on the final round to
        report the best iteration of a monitored validation metric. On its first
        call it raises ``ValueError`` if there are no evaluation results.
    """
    # Per-entry state, indexed like ``env.evaluation_result_list``. Lists, and
    # single-element lists used as mutable cells, let the closures update state
    # without ``nonlocal`` (unavailable in Python 2).
    best_score = []
    best_iter = []
    best_score_list = []
    cmp_op = []
    enabled = [True]
    first_metric = ['']

    def _init(env):
        """Set up per-entry state from the first evaluation results seen."""
        enabled[0] = not any(env.params.get(boost_alias) == 'dart'
                             for boost_alias in ('boosting', 'boosting_type', 'boost'))
        if not enabled[0]:
            warnings.warn('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if verbose:
            msg = "Training until validation scores don't improve for {} rounds"
            print(msg.format(stopping_rounds))

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret in env.evaluation_result_list:
            best_iter.append(0)
            best_score_list.append(None)
            if eval_ret[3]:  # higher is better
                best_score.append(float('-inf'))
                cmp_op.append(gt)
            else:
                best_score.append(float('inf'))
                cmp_op.append(lt)

    def _final_iteration_check(env, eval_name_splitted, i):
        """On the last round, stop and report the best result of entry ``i``."""
        if env.iteration == env.end_iteration - 1:
            if verbose:
                print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
                    best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                if first_metric_only:
                    print("Evaluated only: {}".format(eval_name_splitted[-1]))
            raise EarlyStopException(best_iter[i], best_score_list[i])

    def _callback(env):
        # Initialise lazily on the first call. Once disabled (dart mode), stay
        # disabled instead of re-running ``_init`` (and re-warning) every round.
        if not cmp_op and enabled[0]:
            _init(env)
        if not enabled[0]:
            return
        for i, eval_ret in enumerate(env.evaluation_result_list):
            data_name, eval_name, score = eval_ret[0], eval_ret[1], eval_ret[2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = eval_name.split(" ")
            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((data_name == "cv_agg" and eval_name_splitted[0] == "train")
                    or data_name == env.model._train_data_name):
                # Train data for lgb.cv or sklearn wrapper (underlying lgb.train) never
                # stops training. This includes the final round, where it would otherwise
                # report the training set's best iteration instead of the validation one.
                continue
            if env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    print('Early stopping, best iteration is:\n[%d]\t%s' % (
                        best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                    if first_metric_only:
                        print("Evaluated only: {}".format(eval_name_splitted[-1]))
                raise EarlyStopException(best_iter[i], best_score_list[i])
            _final_iteration_check(env, eval_name_splitted, i)
    _callback.order = 30
    return _callback