# coding: utf-8
# pylint: disable = invalid-name, W0105, C0301
from __future__ import absolute_import
import collections

class EarlyStopException(Exception):
    """Signal raised to stop training early.

    The early-stopping callback raises this and attaches the iteration at
    which the best score was observed.

    Parameters
    ----------
    best_iteration : int
        The iteration with the best evaluation score.
    """

    def __init__(self, best_iteration):
        Exception.__init__(self)
        self.best_iteration = best_iteration

# Immutable record handed to each callback. The callbacks in this module read
# evaluation_result_list entries as
# (data_name, eval_name, result, is_higher_better[, stdv]).
CallbackEnv = collections.namedtuple(
    "LightGBMCallbackEnv",
    "model cvfolds iteration begin_iteration end_iteration "
    "evaluation_result_list")

def _format_eval_result(value, show_stdv=True):
    """format metric string"""
    if len(value) == 4:
        return '%s\'s %s:%g' % (value[0], value[1], value[2])
    elif len(value) == 5:
        if show_stdv:
            return '%s\'s %s:%g+%g' % (value[0], value[1], value[2], value[4])
        else:
            return '%s\'s %s:%g' % (value[0], value[1], value[2])
    else:
37
        raise ValueError("Wrong metric value")
wxchan's avatar
wxchan committed
38
39
40
41
42
43
44
45
46
47
48


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int, optional
        The period (in iterations) at which to print the evaluation results.
        Nothing is printed when ``period <= 0``.
    show_stdv : bool, optional
        Whether to show the standard deviation, if provided.

    Returns
    -------
    callback : function
        A callback that prints the evaluation results every ``period``
        iterations. It does nothing when there are no evaluation results.
    """
    def callback(env):
        """internal function"""
        if not env.evaluation_result_list or period <= 0:
            return
        # The round number is printed as iteration + 1.
        if (env.iteration + 1) % period == 0:
            result = '\t'.join(_format_eval_result(x, show_stdv)
                               for x in env.evaluation_result_list)
            print('[%d]\t%s' % (env.iteration + 1, result))
    callback.order = 10
    return callback


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        A dictionary to store the evaluation results. It is cleared first.
        Afterwards, ``eval_result[data_name][eval_name]`` is a list that gets
        one value appended per callback invocation.

    Returns
    -------
    callback : function
        The requested callback function.

    Raises
    ------
    TypeError
        If ``eval_result`` is not a dict.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('Eval_result should be a dictionary')
    eval_result.clear()

    def init(env):
        """internal function"""
        for item in env.evaluation_result_list:
            eval_result.setdefault(item[0], collections.defaultdict(list))

    def callback(env):
        """internal function"""
        if not eval_result:
            init(env)
        # Index rather than unpack: entries may carry a 5th element (stdv),
        # which a fixed 4-way unpacking would reject with ValueError.
        for item in env.evaluation_result_list:
            data_name, eval_name, result = item[0], item[1], item[2]
            eval_result[data_name][eval_name].append(result)
    callback.order = 20
    return callback


def reset_parameter(**kwargs):
    """Create a callback that resets parameters after the first iteration.

    NOTE: the initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : list, tuple or callable
        Parameter values to set, keyed by parameter name.
        ``current_round`` below is ``env.iteration - env.begin_iteration``.

        - list/tuple ``l``: parameter = ``l[current_round]``. Its length must
          equal ``env.end_iteration - env.begin_iteration``.
        - callable ``f``: parameter = ``f(current_round)``, e.g. to implement
          learning-rate decay.

    Returns
    -------
    callback : function
        The requested callback function. It raises ValueError for a sequence
        of the wrong length and TypeError for a value that is neither a
        list/tuple nor callable.
    """
    def callback(env):
        """internal function"""
        current_round = env.iteration - env.begin_iteration
        new_parameters = {}
        for key, value in kwargs.items():
            if isinstance(value, (list, tuple)):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError("Length of list {} has to equal to 'num_boost_round'.".format(repr(key)))
                new_parameters[key] = value[current_round]
            elif callable(value):
                new_parameters[key] = value(current_round)
            else:
                raise TypeError("Value of {} should be a list, tuple or callable, got {}."
                                .format(repr(key), type(value).__name__))
        # Apply every parameter in a single call instead of one call per key.
        if new_parameters:
            env.model.reset_parameter(new_parameters)
    callback.before_iteration = True
    callback.order = 10
    return callback


def early_stopping(stopping_rounds, verbose=True):
    """Create a callback that activates early stopping.

    Requires at least one evaluation result (dataset and metric). If there
    is more than one, all of them are checked. Training stops as soon as any
    one of them has not improved for ``stopping_rounds`` rounds.

    Parameters
    ----------
    stopping_rounds : int
        Number of rounds without improvement after which training is stopped.
    verbose : bool, optional
        Whether to print messages about early stopping.

    Returns
    -------
    callback : function
        The requested callback function.
        - It raises ValueError on its first call if ``env.evaluation_result_list``
          is empty.
        - It raises EarlyStopException, carrying the best iteration, when a
          metric has stopped improving.
    """
    factor_to_bigger_better = {}
    best_score = {}
    best_iter = {}
    best_msg = {}

    def init(env):
        """internal function"""
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, at least one dataset is required for evaluation')

        if verbose:
            msg = "Train until valid scores didn't improve in {} rounds."
            print(msg.format(stopping_rounds))

        for i, item in enumerate(env.evaluation_result_list):
            best_score[i] = float('-inf')
            best_iter[i] = 0
            if verbose:
                best_msg[i] = ""
            # Negate lower-is-better metrics so that "bigger is better" holds
            # for every tracked score.
            factor_to_bigger_better[i] = 1.0 if item[3] else -1.0

    def callback(env):
        """internal function"""
        if not best_score:
            init(env)
        for i, item in enumerate(env.evaluation_result_list):
            score = item[2] * factor_to_bigger_better[i]
            if score > best_score[i]:
                best_score[i] = score
                best_iter[i] = env.iteration
                if verbose:
                    best_msg[i] = '[%d]\t%s' % (
                        env.iteration + 1,
                        '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if env.model is not None:
                    env.model.set_attr(best_iteration=str(best_iter[i]))
                if verbose:
                    print('Early stopping, best iteration is:')
                    print(best_msg[i])
                raise EarlyStopException(best_iter[i])
    callback.order = 30
    return callback