callback.py 7.46 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
wxchan's avatar
wxchan committed
2
# pylint: disable = invalid-name, W0105, C0301
3
"""Callbacks library."""
wxchan's avatar
wxchan committed
4
from __future__ import absolute_import
5

wxchan's avatar
wxchan committed
6
import collections
wxchan's avatar
wxchan committed
7
from operator import gt, lt
wxchan's avatar
wxchan committed
8

wxchan's avatar
wxchan committed
9
10
from .compat import range_

wxchan's avatar
wxchan committed
11

wxchan's avatar
wxchan committed
12
class EarlyStopException(Exception):
    """Exception raised to terminate training early.

    Carries the best iteration seen so far and its score so the training
    loop can report/restore them after catching the exception.
    """

    def __init__(self, best_iteration, best_score):
        """Initialize the early-stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration stopped.
        best_score : float
            The score of the best iteration.
        """
        super(EarlyStopException, self).__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score
wxchan's avatar
wxchan committed
28

wxchan's avatar
wxchan committed
29

wxchan's avatar
wxchan committed
30
31
32
33
# Callback environment used by callbacks:
# an immutable snapshot of training state handed to every callback each iteration.
CallbackEnv = collections.namedtuple(
    "LightGBMCallbackEnv",
    ["model", "params", "iteration",
     "begin_iteration", "end_iteration",
     "evaluation_result_list"])

wxchan's avatar
wxchan committed
40

wxchan's avatar
wxchan committed
41
def _format_eval_result(value, show_stdv=True):
42
    """Format metric string."""
wxchan's avatar
wxchan committed
43
    if len(value) == 4:
44
        return '%s\'s %s: %g' % (value[0], value[1], value[2])
wxchan's avatar
wxchan committed
45
46
    elif len(value) == 5:
        if show_stdv:
47
            return '%s\'s %s: %g + %g' % (value[0], value[1], value[2], value[4])
wxchan's avatar
wxchan committed
48
        else:
49
            return '%s\'s %s: %g' % (value[0], value[1], value[2])
wxchan's avatar
wxchan committed
50
    else:
51
        raise ValueError("Wrong metric value")
wxchan's avatar
wxchan committed
52
53
54


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to print the evaluation results.
        A non-positive period disables printing.
    show_stdv : bool, optional (default=True)
        Whether to show stdv (if provided).

    Returns
    -------
    callback : function
        The callback that prints the evaluation results every ``period`` iteration(s).
    """
    def _callback(env):
        # Guard clauses: printing disabled, nothing to print,
        # or not on a reporting iteration.
        if period <= 0 or not env.evaluation_result_list:
            return
        if (env.iteration + 1) % period != 0:
            return
        result = '\t'.join(_format_eval_result(item, show_stdv)
                           for item in env.evaluation_result_list)
        print('[%d]\t%s' % (env.iteration + 1, result))
    _callback.order = 10
    return _callback
wxchan's avatar
wxchan committed
75
76
77


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
       A dictionary to store the evaluation results.
       It is cleared first and then filled as
       ``eval_result[data_name][eval_name] -> list of results``.

    Returns
    -------
    callback : function
        The callback that records the evaluation history into the passed dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('Eval_result should be a dictionary')
    eval_result.clear()

    def _init(env):
        # Entries may be 4-tuples (data_name, eval_name, result, is_higher_better)
        # or 5-tuples with a trailing stdv (as produced by cv), so index into the
        # tuple instead of hard-unpacking four fields.
        for item in env.evaluation_result_list:
            eval_result.setdefault(item[0], collections.defaultdict(list))

    def _callback(env):
        if not eval_result:
            _init(env)
        for item in env.evaluation_result_list:
            data_name, eval_name, result = item[:3]
            eval_result[data_name][eval_name].append(result)
    _callback.order = 20
    return _callback
wxchan's avatar
wxchan committed
105
106


107
def reset_parameter(**kwargs):
    """Create a callback that resets the parameter after the first iteration.

    Note
    ----
    The initial parameter will still take in-effect on first iteration.

    Parameters
    ----------
    **kwargs : value should be list or function
        List of parameters for each boosting round
        or a customized function that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If function func, parameter = func(current_round).

    Returns
    -------
    callback : function
        The callback that resets the parameter after the first iteration.

    Raises
    ------
    RuntimeError
        If a structural parameter (num_class, boosting, metric, ...) is passed.
    ValueError
        If a list has the wrong length or a value is neither list nor callable.
    """
    def _callback(env):
        new_parameters = {}
        for key, value in kwargs.items():
            # These parameters define the model structure / objective
            # and cannot be changed in the middle of training.
            if key in ['num_class', 'num_classes',
                       'boosting', 'boost', 'boosting_type',
                       'metric', 'metrics', 'metric_types']:
                raise RuntimeError("cannot reset {} during training".format(repr(key)))
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError("Length of list {} has to equal to 'num_boost_round'."
                                     .format(repr(key)))
                new_param = value[env.iteration - env.begin_iteration]
            elif callable(value):
                new_param = value(env.iteration - env.begin_iteration)
            else:
                # Previously a scalar fell through to value(...) and crashed
                # with an opaque TypeError; fail with a clear message instead.
                raise ValueError("Only list and callable values are supported "
                                 "as a mapping from boosting round index to new parameter value.")
            # Only push parameters whose value actually changed.
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    _callback.before_iteration = True
    _callback.order = 10
    return _callback
wxchan's avatar
wxchan committed
150
151


152
def early_stopping(stopping_rounds, verbose=True):
    """Create a callback that activates early stopping.

    Note
    ----
    The model will train until the validation score stops improving.
    Validation score needs to improve at least every ``stopping_rounds``
    round(s) to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them.

    Parameters
    ----------
    stopping_rounds : int
       The possible number of rounds without the trend occurrence.
    verbose : bool, optional (default=True)
        Whether to print message with early stopping information.

    Returns
    -------
    callback : function
        The callback that activates early stopping.
    """
    # Per-metric state, indexed in the same order as env.evaluation_result_list.
    best_score = []
    best_iter = []
    best_score_list = []
    cmp_op = []

    def _init(env):
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if verbose:
            msg = "Training until validation scores don't improve for {} rounds."
            print(msg.format(stopping_rounds))

        for result in env.evaluation_result_list:
            best_iter.append(0)
            best_score_list.append(None)
            # result[3] is the is_higher_better flag for this metric.
            higher_better = result[3]
            best_score.append(float('-inf') if higher_better else float('inf'))
            cmp_op.append(gt if higher_better else lt)

    def _callback(env):
        if not cmp_op:
            _init(env)
        for idx, result in enumerate(env.evaluation_result_list):
            current = result[2]
            if cmp_op[idx](current, best_score[idx]):
                best_score[idx] = current
                best_iter[idx] = env.iteration
                best_score_list[idx] = env.evaluation_result_list
            elif env.iteration - best_iter[idx] >= stopping_rounds:
                if verbose:
                    summary = '\t'.join(_format_eval_result(x) for x in best_score_list[idx])
                    print('Early stopping, best iteration is:\n[%d]\t%s' % (
                        best_iter[idx] + 1, summary))
                raise EarlyStopException(best_iter[idx], best_score_list[idx])
            if env.iteration == env.end_iteration - 1:
                if verbose:
                    summary = '\t'.join(_format_eval_result(x) for x in best_score_list[idx])
                    print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
                        best_iter[idx] + 1, summary))
                raise EarlyStopException(best_iter[idx], best_score_list[idx])
    _callback.order = 30
    return _callback