callback.py 6.56 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
wxchan's avatar
wxchan committed
2
# pylint: disable = invalid-name, W0105, C0301
wxchan's avatar
wxchan committed
3
from __future__ import absolute_import
4

wxchan's avatar
wxchan committed
5
6
import collections

wxchan's avatar
wxchan committed
7
8
from .compat import range_

wxchan's avatar
wxchan committed
9

wxchan's avatar
wxchan committed
10
11
12
13
14
15
16
17
18
19
20
class EarlyStopException(Exception):
    """Exception raised by a callback to stop boosting early.

    Parameters
    ----------
    best_iteration : int
        The iteration that produced the best score before stopping.
    """

    def __init__(self, best_iteration):
        # Base Exception is initialised without arguments, so ``args`` is empty.
        super(EarlyStopException, self).__init__()
        # Carried on the instance so whoever catches the exception can read it.
        self.best_iteration = best_iteration

wxchan's avatar
wxchan committed
21

wxchan's avatar
wxchan committed
22
23
24
25
# Callback environment handed to every callback: the model, its params, the
# current iteration, the begin/end iteration bounds and the evaluation results.
CallbackEnv = collections.namedtuple(
    "LightGBMCallbackEnv",
    ("model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"),
)

wxchan's avatar
wxchan committed
32

wxchan's avatar
wxchan committed
33
34
35
def _format_eval_result(value, show_stdv=True):
    """Format one evaluation-result tuple as a readable string.

    Parameters
    ----------
    value : tuple
        A 4-item ``(data_name, eval_name, result, is_higher_better)`` tuple,
        or a 5-item tuple whose last element is shown as ``+ <stdv>``.
    show_stdv : bool, optional
        For 5-item tuples, whether to append the last element.

    Returns
    -------
    str
        ``"<data_name>'s <eval_name>: <result>"``, optionally followed by
        ``" + <stdv>"``.

    Raises
    ------
    ValueError
        If ``value`` has neither 4 nor 5 items.
    """
    size = len(value)
    if size != 4 and size != 5:
        raise ValueError("Wrong metric value")
    if size == 5 and show_stdv:
        return '%s\'s %s: %g + %g' % (value[0], value[1], value[2], value[4])
    # 4-item tuples, and 5-item tuples with show_stdv off, share this format.
    return '%s\'s %s: %g' % (value[0], value[1], value[2])
wxchan's avatar
wxchan committed
44
45
46
47
48
49
50
51
52
53
54


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int
        Print every ``period`` iterations; a non-positive value disables printing.

    show_stdv : bool, optional
        Whether to show the stdv part when a result provides one.

    Returns
    -------
    callback : function
        A callback that prints the evaluation results every ``period`` iterations.
    """
    def callback(env):
        """Print the results held by ``env`` on every ``period``-th iteration."""
        if period <= 0 or not env.evaluation_result_list:
            return
        round_number = env.iteration + 1  # reported rounds are 1-based
        if round_number % period != 0:
            return
        pieces = [_format_eval_result(item, show_stdv) for item in env.evaluation_result_list]
        print('[%d]\t%s' % (round_number, '\t'.join(pieces)))
    callback.order = 10
    return callback


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into eval_result.

    Parameters
    ----------
    eval_result : dict
       A dictionary to store the evaluation results. It is cleared when the
       callback is created and then filled as
       ``eval_result[data_name][eval_name] -> list of results``.

    Returns
    -------
    callback : function
        The requested callback function.

    Raises
    ------
    TypeError
        If ``eval_result`` is not a dict.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('Eval_result should be a dictionary')
    eval_result.clear()

    def callback(env):
        """Append this iteration's results to ``eval_result``."""
        for item in env.evaluation_result_list:
            # Index rather than unpack so both the 4-item form
            # (data_name, eval_name, result, is_higher_better) and the 5-item
            # form with a trailing stdv are accepted.
            data_name, eval_name, result = item[0], item[1], item[2]
            # setdefault on every item (instead of a one-time init) also
            # handles a data name that first appears in a later iteration.
            history = eval_result.setdefault(data_name, collections.defaultdict(list))
            history[eval_name].append(result)
    callback.order = 20
    return callback


103
104
def reset_parameter(**kwargs):
    """Create a callback that resets parameters during training.

    NOTE(review): an earlier note said the initial parameter still takes
    effect on the first iteration; the callback is flagged
    ``before_iteration = True``, so whether that holds depends on when the
    training loop invokes it -- confirm against the engine.

    Parameters
    ----------
    **kwargs : list, tuple or callable
        For each parameter name, either a sequence with exactly one value per
        boosting round, or a function of the current round index (counted
        from ``begin_iteration``) that returns the value, e.g. for
        learning-rate decay.
        - sequence l: parameter = l[current_round]
        - function f: parameter = f(current_round)

    Returns
    -------
    callback : function
        The requested callback function.

    Raises
    ------
    The returned callback raises:
    RuntimeError
        If asked to reset 'num_class', 'boosting_type' or 'metric'.
    ValueError
        If a sequence's length differs from end_iteration - begin_iteration.
    TypeError
        If a value is neither a list/tuple nor callable.
    """
    def callback(env):
        """Compute this round's parameter values and push any that changed."""
        new_parameters = {}
        current_round = env.iteration - env.begin_iteration
        for key, value in kwargs.items():
            if key in ['num_class', 'boosting_type', 'metric']:
                raise RuntimeError("cannot reset {} during training".format(repr(key)))
            if isinstance(value, (list, tuple)):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError("Length of list {} has to equal to 'num_boost_round'.".format(repr(key)))
                new_param = value[current_round]
            elif callable(value):
                new_param = value(current_round)
            else:
                # Previously this failed with an opaque "object is not callable".
                raise TypeError("Only list, tuple or callable values are supported for {}.".format(repr(key)))
            # Only send parameters whose value actually changes.
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    callback.before_iteration = True
    callback.order = 10
    return callback


143
def early_stopping(stopping_rounds, verbose=True):
    """Create a callback that activates early stopping.

    Requires at least one validation data set and one metric. When several
    are present, all of them are checked, and the first one (in list order)
    that has not improved for ``stopping_rounds`` rounds stops training.

    Parameters
    ----------
    stopping_rounds : int
       Number of rounds without improvement that triggers the stop.

    verbose : optional, bool
        Whether to print message about early stopping information.

    Returns
    -------
    callback : function
        The requested callback function.

    Raises
    ------
    The returned callback raises:
    ValueError
        On its first call, if there are no evaluation results.
    EarlyStopException
        When a metric stalls; the model's ``best_iteration`` attribute is
        set first.
    """
    # Per-metric state, keyed by position in env.evaluation_result_list.
    factor_to_bigger_better = {}
    best_score = {}
    best_iter = {}
    best_msg = {}

    def init(env):
        """Validate ``env`` and seed the per-metric tracking state."""
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, at least one dataset or eval metric is required for evaluation')

        if verbose:
            msg = "Train until valid scores didn't improve in {} rounds."
            print(msg.format(stopping_rounds))

        for idx, entry in enumerate(env.evaluation_result_list):
            best_score[idx] = float('-inf')
            best_iter[idx] = 0
            if verbose:
                best_msg[idx] = ""
            # Negate lower-is-better metrics so that ">" always means "improved".
            factor_to_bigger_better[idx] = 1.0 if entry[3] else -1.0

    def callback(env):
        """Track the best score per metric and stop once one has stalled."""
        if not best_score:
            init(env)
        for idx, entry in enumerate(env.evaluation_result_list):
            score = entry[2] * factor_to_bigger_better[idx]
            if score > best_score[idx]:
                best_score[idx] = score
                best_iter[idx] = env.iteration
                if verbose:
                    best_msg[idx] = '[%d]\t%s' % (
                        env.iteration + 1,
                        '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]),
                    )
                continue
            if env.iteration - best_iter[idx] < stopping_rounds:
                continue
            env.model.set_attr(best_iteration=str(best_iter[idx]))
            if verbose:
                print('Early stopping, best iteration is:')
                print(best_msg[idx])
            raise EarlyStopException(best_iter[idx])
    callback.order = 30
    return callback