callback.py 6.62 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
wxchan's avatar
wxchan committed
2
# pylint: disable = invalid-name, W0105, C0301
wxchan's avatar
wxchan committed
3
from __future__ import absolute_import
4

wxchan's avatar
wxchan committed
5
import collections
wxchan's avatar
wxchan committed
6
from operator import gt, lt
wxchan's avatar
wxchan committed
7

wxchan's avatar
wxchan committed
8
9
from .compat import range_

wxchan's avatar
wxchan committed
10

wxchan's avatar
wxchan committed
11
12
13
14
15
16
17
18
19
20
21
class EarlyStopException(Exception):
    """Exception raised by a callback to stop training early.

    Parameters
    ----------
    best_iteration : int
        The best iteration recorded at the moment training was stopped.
    """

    def __init__(self, best_iteration):
        # No message/args are passed to the base class; the payload is
        # exposed only through the ``best_iteration`` attribute.
        super(EarlyStopException, self).__init__()
        self.best_iteration = best_iteration

wxchan's avatar
wxchan committed
22

wxchan's avatar
wxchan committed
23
24
25
26
# Environment record passed to the callbacks in this module; each callback
# reads fields such as ``iteration`` and ``evaluation_result_list`` from it.
# The namedtuple itself is immutable, but ``params`` may hold a mutable dict.
CallbackEnv = collections.namedtuple(
    "LightGBMCallbackEnv",
    "model params iteration begin_iteration end_iteration evaluation_result_list")

wxchan's avatar
wxchan committed
33

wxchan's avatar
wxchan committed
34
35
36
def _format_eval_result(value, show_stdv=True):
    """Format one evaluation entry as ``"<data>'s <metric>: <score>"``.

    Parameters
    ----------
    value : tuple
        Either ``(data_name, eval_name, score, flag)`` or
        ``(data_name, eval_name, score, flag, stdv)``; ``value[3]`` is not
        used by this function.
    show_stdv : bool, optional
        For 5-field entries, append ``" + <stdv>"`` when True.

    Returns
    -------
    str
        The formatted metric string.

    Raises
    ------
    ValueError
        If ``value`` has neither 4 nor 5 fields.
    """
    field_count = len(value)
    if field_count not in (4, 5):
        raise ValueError("Wrong metric value")
    text = '%s\'s %s: %g' % (value[0], value[1], value[2])
    if field_count == 5 and show_stdv:
        text += ' + %g' % value[4]
    return text
wxchan's avatar
wxchan committed
45
46
47
48
49
50
51
52
53
54
55


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int, optional (default=1)
        Print the results every ``period`` iterations. A non-positive value
        disables printing.
    show_stdv : bool, optional (default=True)
        Whether to show the standard deviation when an entry provides one.

    Returns
    -------
    callback : function
        The callback that prints the evaluation results every ``period``
        iterations. Its ``order`` attribute is 10.
    """
    def callback(env):
        """internal function"""
        if period <= 0:
            return
        results = env.evaluation_result_list
        if not results:
            return
        # Iterations are reported 1-based.
        round_no = env.iteration + 1
        if round_no % period != 0:
            return
        line = '\t'.join(_format_eval_result(entry, show_stdv) for entry in results)
        print('[%d]\t%s' % (round_no, line))
    callback.order = 10
    return callback


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into eval_result.

    Parameters
    ----------
    eval_result : dict
        A dictionary to store the evaluation results. It is cleared when the
        callback is created and then filled as
        ``eval_result[data_name][eval_name] -> list of scores``, one score
        appended per callback call.

    Returns
    -------
    callback : function
        The requested callback function. Its ``order`` attribute is 20.

    Raises
    ------
    TypeError
        If ``eval_result`` is not a dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('Eval_result should be a dictionary')
    eval_result.clear()

    def callback(env):
        """internal function"""
        for item in env.evaluation_result_list:
            # Entries are (data_name, eval_name, score, ...) with either 4 or
            # 5 fields (the 5-field form also carries a stdv, as handled by
            # _format_eval_result). Index rather than unpacking exactly four
            # values so both forms are recorded; item[2] is the score, the
            # same field early_stopping compares.
            data_name, eval_name, score = item[0], item[1], item[2]
            # setdefault creates the per-dataset history on demand, so a
            # dataset that first appears on a later call does not raise
            # KeyError.
            history = eval_result.setdefault(data_name, collections.defaultdict(list))
            history[eval_name].append(score)
    callback.order = 20
    return callback


104
105
def reset_parameter(**kwargs):
    """Create a callback that resets parameters during training.

    NOTE(review): the original docs say the initial parameter still takes
    effect on the first iteration; this callback is flagged
    ``before_iteration``, so confirm that against the training loop.

    Parameters
    ----------
    **kwargs : list or callable
        One schedule per parameter name.
        - list ``l``: parameter = ``l[current_round]``; its length must equal
          ``end_iteration - begin_iteration`` (i.e. 'num_boost_round').
        - callable ``f``: parameter = ``f(current_round)``, e.g. a learning
          rate decay.
        ``current_round`` is ``env.iteration - env.begin_iteration``.
        'num_class', 'boosting_type' and 'metric' cannot be reset
        (RuntimeError).

    Returns
    -------
    callback : function
        The requested callback function (``before_iteration`` is True,
        ``order`` is 10). Only values that differ from ``env.params`` are
        passed to ``env.model.reset_parameter`` and merged into ``env.params``.
    """
    def callback(env):
        """internal function"""
        updates = {}
        for name, schedule in kwargs.items():
            if name in ('num_class', 'boosting_type', 'metric'):
                raise RuntimeError("cannot reset {} during training".format(repr(name)))
            if not isinstance(schedule, list):
                candidate = schedule(env.iteration - env.begin_iteration)
            else:
                if len(schedule) != env.end_iteration - env.begin_iteration:
                    raise ValueError("Length of list {} has to equal to 'num_boost_round'.".format(repr(name)))
                candidate = schedule[env.iteration - env.begin_iteration]
            # Skip values that are already in effect to avoid redundant resets.
            if candidate != env.params.get(name, None):
                updates[name] = candidate
        if updates:
            env.model.reset_parameter(updates)
            env.params.update(updates)
    callback.before_iteration = True
    callback.order = 10
    return callback


144
def early_stopping(stopping_rounds, verbose=True):
    """Create a callback that activates early stopping.

    Requires at least one validation dataset and one metric. When there are
    several, every (dataset, metric) entry is checked independently, and
    training stops as soon as any one of them has gone ``stopping_rounds``
    iterations without improving.

    Parameters
    ----------
    stopping_rounds : int
        Number of iterations without improvement after which training stops.
    verbose : bool, optional (default=True)
        Whether to print messages about early stopping.

    Returns
    -------
    callback : function
        The requested callback function (``order`` is 30). When stopping, it
        stores the best iteration on the model via
        ``set_attr(best_iteration=...)`` and raises ``EarlyStopException``.

    Raises
    ------
    ValueError
        (from the callback) if the evaluation result list is empty.
    """
    # Per-entry state, filled lazily on the first callback call because the
    # evaluation result list is only known once training is running.
    best_score = []
    best_iter = []
    best_msg = []
    cmp_op = []

    def init(env):
        """internal function"""
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, at least one dataset and eval metric is required for evaluation')

        if verbose:
            msg = "Train until valid scores didn't improve in {} rounds."
            print(msg.format(stopping_rounds))

        for eval_ret in env.evaluation_result_list:
            # Seed with the first observed iteration (not 0) so an entry that
            # never improves is counted from where this run actually started.
            best_iter.append(env.iteration)
            if verbose:
                best_msg.append(None)
            # A truthy eval_ret[3] means larger scores are better: track the
            # maximum; otherwise track the minimum.
            if eval_ret[3]:
                best_score.append(float('-inf'))
                cmp_op.append(gt)
            else:
                best_score.append(float('inf'))
                cmp_op.append(lt)

    def callback(env):
        """internal function"""
        if not cmp_op:
            init(env)
        best_msg_buffer = None
        for i, eval_ret in enumerate(env.evaluation_result_list):
            score = eval_ret[2]
            if cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                if verbose:
                    # Build the full result line at most once per iteration and
                    # share it between all entries that improved.
                    if not best_msg_buffer:
                        best_msg_buffer = '[%d]\t%s' % (
                            env.iteration + 1, '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
                    best_msg[i] = best_msg_buffer
            elif env.iteration - best_iter[i] >= stopping_rounds:
                env.model.set_attr(best_iteration=str(best_iter[i]))
                if verbose:
                    # best_msg[i] stays None when this entry never improved
                    # (e.g. NaN scores never compare better than +/-inf); fall
                    # back to the bare iteration instead of failing on str + None.
                    best_line = best_msg[i]
                    if best_line is None:
                        best_line = '[%d]' % (best_iter[i] + 1)
                    print('Early stopping, best iteration is:\n' + best_line)
                raise EarlyStopException(best_iter[i])
    callback.order = 30
    return callback