".github/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "9f1af051b44564eaab2bebe1612c6a52217bb32b"
callback.py 7.4 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
wxchan's avatar
wxchan committed
2
# pylint: disable = invalid-name, W0105, C0301
wxchan's avatar
wxchan committed
3
from __future__ import absolute_import
4

wxchan's avatar
wxchan committed
5
import collections
wxchan's avatar
wxchan committed
6
from operator import gt, lt
wxchan's avatar
wxchan committed
7

wxchan's avatar
wxchan committed
8
9
from .compat import range_

wxchan's avatar
wxchan committed
10

wxchan's avatar
wxchan committed
11
12
class EarlyStopException(Exception):
    """Exception of early stopping.

    Raised by the early-stopping callback to terminate the training loop;
    the trainer catches it and reads back the best iteration and score.

    Parameters
    ----------
    best_iteration : int
        The best iteration stopped.
    best_score : list
        The evaluation results at the best iteration
        (same structure as ``CallbackEnv.evaluation_result_list``).
    """
    def __init__(self, best_iteration, best_score):
        super(EarlyStopException, self).__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score
wxchan's avatar
wxchan committed
23

wxchan's avatar
wxchan committed
24

wxchan's avatar
wxchan committed
25
26
27
28
# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "LightGBMCallbackEnv",
    "model params iteration begin_iteration end_iteration evaluation_result_list")

wxchan's avatar
wxchan committed
35

wxchan's avatar
wxchan committed
36
37
38
def _format_eval_result(value, show_stdv=True):
    """format metric string"""
    if len(value) == 4:
39
        return '%s\'s %s: %g' % (value[0], value[1], value[2])
wxchan's avatar
wxchan committed
40
41
    elif len(value) == 5:
        if show_stdv:
42
            return '%s\'s %s: %g + %g' % (value[0], value[1], value[2], value[4])
wxchan's avatar
wxchan committed
43
        else:
44
            return '%s\'s %s: %g' % (value[0], value[1], value[2])
wxchan's avatar
wxchan committed
45
    else:
46
        raise ValueError("Wrong metric value")
wxchan's avatar
wxchan committed
47
48
49


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to print the evaluation results.
    show_stdv : bool, optional (default=True)
        Whether to show stdv (if provided).

    Returns
    -------
    callback : function
        The callback that prints the evaluation results every ``period`` iteration(s).
    """
    def callback(env):
        """internal function"""
        # Guard clauses: a non-positive period disables printing entirely,
        # and iterations are reported 1-based, hence the (iteration + 1).
        if period <= 0:
            return
        if not env.evaluation_result_list:
            return
        if (env.iteration + 1) % period != 0:
            return
        pieces = [_format_eval_result(entry, show_stdv) for entry in env.evaluation_result_list]
        print('[%d]\t%s' % (env.iteration + 1, '\t'.join(pieces)))
    callback.order = 10
    return callback


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
       A dictionary to store the evaluation results.
       It is cleared first, then filled as ``{data_name: {eval_name: [scores...]}}``.

    Returns
    -------
    callback : function
        The callback that records the evaluation history into the passed dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('Eval_result should be a dictionary')
    eval_result.clear()

    def _prepare(env):
        """internal function"""
        # One defaultdict(list) per dataset, created lazily from the first result list.
        for item in env.evaluation_result_list:
            eval_result.setdefault(item[0], collections.defaultdict(list))

    def callback(env):
        """internal function"""
        if not eval_result:
            _prepare(env)
        # item is (data_name, eval_name, result, is_higher_better[, stdv])
        for item in env.evaluation_result_list:
            eval_result[item[0]][item[1]].append(item[2])
    callback.order = 20
    return callback


105
def reset_parameter(**kwargs):
    """Create a callback that resets the parameter after the first iteration.

    Note
    ----
    The initial parameter will still take in-effect on first iteration.

    Parameters
    ----------
    **kwargs: value should be list or function
        List of parameters for each boosting round
        or a customized function that calculates the parameter in terms of
        current number of round (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If function func, parameter = func(current_round).

    Returns
    -------
    callback : function
        The callback that resets the parameter after the first iteration.
    """
    def callback(env):
        """internal function"""
        changed = {}
        for name, spec in kwargs.items():
            # These parameters define the model structure and cannot change mid-training.
            if name in ('num_class', 'num_classes',
                        'boosting', 'boost', 'boosting_type',
                        'metric', 'metrics', 'metric_types'):
                raise RuntimeError("cannot reset {} during training".format(repr(name)))
            if isinstance(spec, list):
                if len(spec) != env.end_iteration - env.begin_iteration:
                    raise ValueError("Length of list {} has to equal to 'num_boost_round'.".format(repr(name)))
                candidate = spec[env.iteration - env.begin_iteration]
            else:
                # spec is assumed callable: parameter = spec(current_round)
                candidate = spec(env.iteration - env.begin_iteration)
            # Only push values that actually differ from the current params.
            if candidate != env.params.get(name, None):
                changed[name] = candidate
        if changed:
            env.model.reset_parameter(changed)
            env.params.update(changed)
    callback.before_iteration = True
    callback.order = 10
    return callback


150
def early_stopping(stopping_rounds, verbose=True):
    """Create a callback that activates early stopping.

    Note
    ----
    Activates early stopping.
    The model will train until the validation score stops improving.
    Validation score needs to improve at least every ``early_stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them. But the training data is ignored anyway.

    Parameters
    ----------
    stopping_rounds : int
       The possible number of rounds without the trend occurrence.
    verbose : bool, optional (default=True)
        Whether to print message with early stopping information.

    Returns
    -------
    callback : function
        The callback that activates early stopping.
    """
    # Per-metric state, one slot per evaluation entry; filled lazily by init().
    best_score = []
    best_iter = []
    best_score_list = []
    cmp_op = []

    def init(env):
        """internal function"""
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, at least one dataset and eval metric is required for evaluation')

        if verbose:
            msg = "Training until validation scores don't improve for {} rounds."
            print(msg.format(stopping_rounds))

        for eval_ret in env.evaluation_result_list:
            best_iter.append(0)
            best_score_list.append(None)
            # eval_ret[3] is the is_higher_better flag for this metric.
            higher_better = eval_ret[3]
            best_score.append(float('-inf') if higher_better else float('inf'))
            cmp_op.append(gt if higher_better else lt)

    def callback(env):
        """internal function"""
        if not cmp_op:
            init(env)
        for i, eval_ret in enumerate(env.evaluation_result_list):
            score = eval_ret[2]
            if cmp_op[i](score, best_score[i]):
                # New best for this metric: remember score, round, and full result list.
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    print('Early stopping, best iteration is:\n[%d]\t%s' % (
                        best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                raise EarlyStopException(best_iter[i], best_score_list[i])
            if env.iteration == env.end_iteration - 1:
                if verbose:
                    print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
                        best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                raise EarlyStopException(best_iter[i], best_score_list[i])
    callback.order = 30
    return callback