# coding: utf-8
"""Callbacks library."""
import collections
import warnings
from operator import gt, lt

from .basic import _ConfigAliases


class EarlyStopException(Exception):
    """Exception of early stopping."""

    def __init__(self, best_iteration, best_score):
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration at which training stopped.
        best_score : float
            The score of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score
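# Usage note (illustrative sketch, not part of the documented API surface): the
# training loop catches this exception to stop boosting, so a custom callback
# can trigger early stopping by raising it itself, e.g.
#     raise EarlyStopException(env.iteration, env.evaluation_result_list)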


# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
    "LightGBMCallbackEnv",
    ["model",
     "params",
     "iteration",
     "begin_iteration",
     "end_iteration",
     "evaluation_result_list"])


def _format_eval_result(value, show_stdv=True):
    """Format metric string."""
    if len(value) == 4:  # (data_name, eval_name, result, is_higher_better)
        return '%s\'s %s: %g' % (value[0], value[1], value[2])
    elif len(value) == 5:  # cv result: (data_name, eval_name, mean, is_higher_better, stdv)
        if show_stdv:
            return '%s\'s %s: %g + %g' % (value[0], value[1], value[2], value[4])
        else:
            return '%s\'s %s: %g' % (value[0], value[1], value[2])
    else:
        raise ValueError("Wrong metric value")


def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to print the evaluation results.
    show_stdv : bool, optional (default=True)
        Whether to show stdv (if provided).

    Returns
    -------
    callback : function
        The callback that prints the evaluation results every ``period`` iteration(s).
    """
    def _callback(env):
        if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            print('[%d]\t%s' % (env.iteration + 1, result))
    _callback.order = 10
    return _callback
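# Illustrative usage (a sketch, assuming the standard ``lgb.train`` entry point):
#
#     import lightgbm as lgb
#     bst = lgb.train(params, train_set, num_boost_round=100,
#                     valid_sets=[valid_set],
#                     callbacks=[lgb.print_evaluation(period=10)])
#
# This prints one line of evaluation results every 10 iterations.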


def record_evaluation(eval_result):
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        A dictionary to store the evaluation results.

    Returns
    -------
    callback : function
        The callback that records the evaluation history into the passed dictionary.
    """
    if not isinstance(eval_result, dict):
        raise TypeError('eval_result should be a dictionary')
    eval_result.clear()

    def _init(env):
        for data_name, eval_name, _, _ in env.evaluation_result_list:
            eval_result.setdefault(data_name, collections.OrderedDict())
            eval_result[data_name].setdefault(eval_name, [])

    def _callback(env):
        if not eval_result:
            _init(env)
        for data_name, eval_name, result, _ in env.evaluation_result_list:
            eval_result[data_name][eval_name].append(result)
    _callback.order = 20
    return _callback
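# Illustrative usage (a sketch; 'valid' is a hypothetical dataset name):
#
#     evals_result = {}
#     bst = lgb.train(params, train_set, valid_sets=[valid_set], valid_names=['valid'],
#                     callbacks=[lgb.record_evaluation(evals_result)])
#     # evals_result -> {'valid': OrderedDict([('l2', [0.68, 0.61, ...])])}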


def reset_parameter(**kwargs):
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter will still take effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or function
        List of parameters for each boosting round
        or a customized function that calculates the parameter for the
        current round number (e.g. to implement learning rate decay).
        If list lst, parameter = lst[current_round].
        If function func, parameter = func(current_round).

    Returns
    -------
    callback : function
        The callback that resets the parameter after the first iteration.
    """
    def _callback(env):
        new_parameters = {}
        for key, value in kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError("Length of list {} has to equal 'num_boost_round'."
                                     .format(repr(key)))
                new_param = value[env.iteration - env.begin_iteration]
            else:
                new_param = value(env.iteration - env.begin_iteration)
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            env.model.reset_parameter(new_parameters)
            env.params.update(new_parameters)
    _callback.before_iteration = True
    _callback.order = 10
    return _callback
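# Illustrative usage (a sketch): decay the learning rate each round, either via
# a function of the round number or via an explicit per-round list whose length
# must equal num_boost_round:
#
#     lgb.reset_parameter(learning_rate=lambda it: 0.1 * (0.99 ** it))
#     lgb.reset_parameter(learning_rate=[0.1] * 50 + [0.05] * 50)  # num_boost_round=100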


def early_stopping(stopping_rounds, first_metric_only=False, verbose=True):
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score stops improving.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation dataset and one metric.
    If there is more than one, all of them will be checked, but the training data is ignored.
    To check only the first metric, set ``first_metric_only`` to True.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds without improvement after which training will be stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to print a message with early stopping information.

    Returns
    -------
    callback : function
        The callback that activates early stopping.
    """
    best_score = []
    best_iter = []
    best_score_list = []
    cmp_op = []
    enabled = [True]
    first_metric = ['']

    def _init(env):
        enabled[0] = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
                             in _ConfigAliases.get("boosting"))
        if not enabled[0]:
            warnings.warn('Early stopping is not available in dart mode')
            return
        if not env.evaluation_result_list:
            raise ValueError('For early stopping, '
                             'at least one dataset and eval metric is required for evaluation')

        if verbose:
            msg = "Training until validation scores don't improve for {} rounds"
            print(msg.format(stopping_rounds))

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        first_metric[0] = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret in env.evaluation_result_list:
            best_iter.append(0)
            best_score_list.append(None)
            if eval_ret[3]:  # the metric's is_higher_better flag
                best_score.append(float('-inf'))
                cmp_op.append(gt)
            else:
                best_score.append(float('inf'))
                cmp_op.append(lt)

    def _final_iteration_check(env, eval_name_splitted, i):
        if env.iteration == env.end_iteration - 1:
            if verbose:
                print('Did not meet early stopping. Best iteration is:\n[%d]\t%s' % (
                    best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                if first_metric_only:
                    print("Evaluated only: {}".format(eval_name_splitted[-1]))
            raise EarlyStopException(best_iter[i], best_score_list[i])

    def _callback(env):
        if not cmp_op:
            _init(env)
        if not enabled[0]:
            return
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
                best_score[i] = score
                best_iter[i] = env.iteration
                best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if first_metric_only and first_metric[0] != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
                _final_iteration_check(env, eval_name_splitted, i)
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if verbose:
                    print('Early stopping, best iteration is:\n[%d]\t%s' % (
                        best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
                    if first_metric_only:
                        print("Evaluated only: {}".format(eval_name_splitted[-1]))
                raise EarlyStopException(best_iter[i], best_score_list[i])
            _final_iteration_check(env, eval_name_splitted, i)
    _callback.order = 30
    return _callback
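# Illustrative usage (a sketch, assuming the standard ``lgb.train`` entry point):
#
#     bst = lgb.train(params, train_set, num_boost_round=1000,
#                     valid_sets=[valid_set],
#                     callbacks=[lgb.early_stopping(stopping_rounds=20)])
#     print(bst.best_iteration)
#
# Training stops once no validation metric has improved for 20 consecutive rounds.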