callback.py 6.46 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
wxchan's avatar
wxchan committed
2
# pylint: disable = invalid-name, W0105, C0301
wxchan's avatar
wxchan committed
3
from __future__ import absolute_import
4

wxchan's avatar
wxchan committed
5
import collections
wxchan's avatar
wxchan committed
6
from operator import gt, lt
wxchan's avatar
wxchan committed
7

wxchan's avatar
wxchan committed
8
9
from .compat import range_

wxchan's avatar
wxchan committed
10

wxchan's avatar
wxchan committed
11
12
13
14
15
16
17
class EarlyStopException(Exception):
    """Exception raised to signal that training should stop early.

    Parameters
    ----------
    best_iteration : int
        The iteration index at which the best score was observed.
    best_score : object
        The evaluation results recorded at the best iteration. It is stored
        unchanged; the early stopping callback in this module passes the
        evaluation result list of that iteration.
    """

    def __init__(self, best_iteration, best_score):
        super(EarlyStopException, self).__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score
wxchan's avatar
wxchan committed
22

wxchan's avatar
wxchan committed
23

wxchan's avatar
wxchan committed
24
25
26
27
# Callback environment handed to every callback. As used in this module:
#   model                  - object whose ``reset_parameter`` method is called
#   params                 - mapping of the current training parameters
#   iteration              - index of the current iteration
#   begin_iteration        - index of the first iteration of this run
#   end_iteration          - index one past the last iteration of this run
#   evaluation_result_list - sequence of evaluation entries; items 0..3 are
#                            read as (data name, metric name, score,
#                            higher-is-better flag)
CallbackEnv = collections.namedtuple(
    "LightGBMCallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])

wxchan's avatar
wxchan committed
34

wxchan's avatar
wxchan committed
35
36
37
def _format_eval_result(value, show_stdv=True):
    """format metric string"""
    if len(value) == 4:
38
        return '%s\'s %s: %g' % (value[0], value[1], value[2])
wxchan's avatar
wxchan committed
39
40
    elif len(value) == 5:
        if show_stdv:
41
            return '%s\'s %s: %g + %g' % (value[0], value[1], value[2], value[4])
wxchan's avatar
wxchan committed
42
        else:
43
            return '%s\'s %s: %g' % (value[0], value[1], value[2])
wxchan's avatar
wxchan committed
44
    else:
45
        raise ValueError("Wrong metric value")
wxchan's avatar
wxchan committed
46
47
48
49
50
51
52
53
54
55
56


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int
        The period, in iterations, at which the evaluation results are
        printed. Nothing is printed when ``period`` is not positive.

    show_stdv : bool, optional
        Whether to show the standard deviation when an entry provides one.

    Returns
    -------
    callback : function
        A callback that prints the evaluation results every ``period``
        iterations.
    """
    def callback(env):
        """internal function"""
        # The period check comes first so that a period of 0 never reaches
        # the modulo below.
        if period <= 0 or not env.evaluation_result_list:
            return
        if (env.iteration + 1) % period != 0:
            return
        result = '\t'.join(_format_eval_result(x, show_stdv) for x in env.evaluation_result_list)
        # Iterations are reported 1-based.
        print('[%d]\t%s' % (env.iteration + 1, result))
    callback.order = 10
    return callback


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into eval_result.

    Parameters
    ----------
    eval_result : dict
        A dictionary to store the evaluation results. It is cleared when the
        callback is created and then filled in place as
        ``eval_result[data_name][eval_name] -> list of scores``.

    Returns
    -------
    callback : function
        The requested callback function.

    Raises
    ------
    TypeError
        If ``eval_result`` is not a dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('Eval_result should be a dictionary')
    eval_result.clear()

    def callback(env):
        """internal function"""
        for item in env.evaluation_result_list:
            # Only the first three items are needed, so entries that carry
            # extra trailing items (such as the 5-element form accepted by
            # _format_eval_result) are recorded too.
            data_name, eval_name, result = item[:3]
            # setdefault creates the per-dataset history on first sight, so
            # a dataset name that first appears in a later call still works.
            history = eval_result.setdefault(data_name, collections.defaultdict(list))
            history[eval_name].append(result)
    callback.order = 20
    return callback


105
106
def reset_parameter(**kwargs):
    """Create a callback that resets parameters before each iteration.

    NOTE: a parameter is only sent to the model when its scheduled value
    differs from the value currently stored in the training parameters.

    Parameters
    ----------
    **kwargs: value should be list, tuple or function
        For each parameter name, either a sequence holding one value per
        boosting round or a function that computes the value from the
        current round number (e.g. to implement learning rate decay)
        - list l: parameter = l[current_round]
        - function f: parameter = f(current_round)
        The round number counts from 0 at the first iteration of the run.

    Returns
    -------
    callback : function
        The requested callback function.

    Raises
    ------
    RuntimeError
        When the callback runs and one of 'num_class', 'boosting_type' or
        'metric' was requested.
    ValueError
        When the callback runs and a sequence does not hold exactly one
        value per iteration of the run.
    """
    def callback(env):
        """internal function"""
        # Round number relative to the start of this training run.
        current_round = env.iteration - env.begin_iteration
        new_parameters = {}
        for key, value in kwargs.items():
            if key in ('num_class', 'boosting_type', 'metric'):
                raise RuntimeError("cannot reset {} during training".format(repr(key)))
            if isinstance(value, (list, tuple)):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError("Length of list {} has to equal to 'num_boost_round'.".format(repr(key)))
                new_param = value[current_round]
            else:
                new_param = value(current_round)
            # Skip values that would not change anything.
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    callback.before_iteration = True
    callback.order = 10
    return callback


145
def early_stopping(stopping_rounds, verbose=True):
    """Create a callback that activates early stopping.

    Requires at least one validation data and one metric.
    If there is more than one, all of them are checked, and training stops
    as soon as any one of them has not improved for ``stopping_rounds``
    iterations.

    Parameters
    ----------
    stopping_rounds : int
        The number of iterations without improvement after which training
        is stopped.

    verbose : optional, bool
        Whether to print message about early stopping information.

    Returns
    -------
    callback : function
        The requested callback function. It raises ``EarlyStopException``
        to stop training and ``ValueError`` when there is nothing to
        evaluate on its first call.
    """
    # Per-metric state, indexed like env.evaluation_result_list and filled
    # lazily on the first call of the callback.
    best_score = []
    best_iter = []
    best_score_list = []
    cmp_op = []

    def init(env):
        """internal function"""
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, at least one dataset and eval metric is required for evaluation')

        if verbose:
            msg = "Training until validation scores don't improve for {} rounds."
            print(msg.format(stopping_rounds))

        for eval_ret in env.evaluation_result_list:
            # Seed the "best so far" with the first observed iteration, so
            # that a metric which never improves (e.g. NaN scores) still has
            # results to report instead of None.
            best_iter.append(env.iteration)
            best_score_list.append(env.evaluation_result_list)
            # Item 3 tells whether a higher score is better.
            if eval_ret[3]:
                best_score.append(float('-inf'))
                cmp_op.append(gt)
            else:
                best_score.append(float('inf'))
                cmp_op.append(lt)

    def callback(env):
        """internal function"""
        if not cmp_op:
            init(env)
        for i, eval_ret in enumerate(env.evaluation_result_list):
            score = eval_ret[2]
            if cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    print('Early stopping, best iteration is:\n[%d]\t%s' % (
                        best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                raise EarlyStopException(best_iter[i], best_score_list[i])
    callback.order = 30
    return callback