"...git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "bd28a3649a1ab16bb8cc10c2f81bfa9cf136778d"
callback.py 19.9 KB
Newer Older
wxchan's avatar
wxchan committed
1
# coding: utf-8
2
"""Callbacks library."""
3

4
5
from collections import OrderedDict
from dataclasses import dataclass
6
from functools import partial
7
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
wxchan's avatar
wxchan committed
8

from .basic import (
    Booster,
    _ConfigAliases,
    _LGBM_BoosterEvalMethodResultType,
    _LGBM_BoosterEvalMethodResultWithStandardDeviationType,
    _log_info,
    _log_warning,
)

if TYPE_CHECKING:
    from .engine import CVBooster

__all__ = [
    "EarlyStopException",
    "early_stopping",
    "log_evaluation",
    "record_evaluation",
    "reset_parameter",
]

_EvalResultDict = Dict[str, Dict[str, List[Any]]]
_EvalResultTuple = Union[
    _LGBM_BoosterEvalMethodResultType,
    _LGBM_BoosterEvalMethodResultWithStandardDeviationType,
]
_ListOfEvalResultTuples = Union[
    List[_LGBM_BoosterEvalMethodResultType],
    List[_LGBM_BoosterEvalMethodResultWithStandardDeviationType],
]


class EarlyStopException(Exception):
    """Exception of early stopping.

    Raise this from a callback passed in via keyword argument ``callbacks``
    in ``cv()`` or ``train()`` to trigger early stopping.
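
    .. rubric:: Example

    A minimal sketch of a custom callback (hypothetical name and stopping
    condition) that raises this exception to stop training:

    .. code-block:: python

        def stop_training_early(env):
            if env.iteration >= 5 and env.evaluation_result_list:
                raise EarlyStopException(
                    best_iteration=env.iteration,
                    best_score=env.evaluation_result_list,
                )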
    """

    def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
        """Create early stopping exception.

        Parameters
        ----------
        best_iteration : int
            The best iteration at which training stopped.
            0-based: pass ``best_iteration=2`` to indicate that the third iteration was the best one.
        best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
            Scores for each metric, on each validation set, as of the best iteration.
        """
        super().__init__()
        self.best_iteration = best_iteration
        self.best_score = best_score


# Callback environment used by callbacks
@dataclass
class CallbackEnv:
    model: Union[Booster, "CVBooster"]
    params: Dict[str, Any]
    iteration: int
    begin_iteration: int
    end_iteration: int
    evaluation_result_list: Optional[_ListOfEvalResultTuples]
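

# Illustrative sketch only (hypothetical helper, not part of the library API):
# any callable accepting a single CallbackEnv works as a callback for ``train()``
# and ``cv()``. It runs after each boosting iteration unless it carries a
# ``before_iteration = True`` attribute.
def _example_print_params_once(env: CallbackEnv) -> None:
    """Log the params in effect on the first boosting round (example only)."""
    if env.iteration == env.begin_iteration:
        _log_info(f"params: {env.params}")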


def _format_eval_result(value: _EvalResultTuple, show_stdv: bool) -> str:
    """Format metric string."""
    if len(value) == 4:
        return f"{value[0]}'s {value[1]}: {value[2]:g}"
    elif len(value) == 5:
        if show_stdv:
            return f"{value[0]}'s {value[1]}: {value[2]:g} + {value[4]:g}"  # type: ignore[misc]
        else:
            return f"{value[0]}'s {value[1]}: {value[2]:g}"
    else:
        raise ValueError("Wrong metric value")


class _LogEvaluationCallback:
    """Internal log evaluation callable class."""

    def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
        self.order = 10
        self.before_iteration = False

        self.period = period
        self.show_stdv = show_stdv

    def __call__(self, env: CallbackEnv) -> None:
        if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0:
            result = "\t".join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list])
            _log_info(f"[{env.iteration + 1}]\t{result}")


def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
    """Create a callback that logs the evaluation results.

    By default, standard output is used.
    Use ``register_logger()`` function to register a custom logger.

    Note
    ----
    Requires at least one validation dataset.

    Parameters
    ----------
    period : int, optional (default=1)
        The period to log the evaluation results.
        The last boosting stage or the boosting stage found by using ``early_stopping`` callback is also logged.
    show_stdv : bool, optional (default=True)
        Whether to log stdv (if provided).

    Returns
    -------
    callback : _LogEvaluationCallback
        The callback that logs the evaluation results every ``period`` boosting iteration(s).
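
    .. rubric:: Example

    A minimal usage sketch (``dtrain`` and ``dvalid`` are hypothetical
    ``Dataset`` objects):

    .. code-block:: python

        import lightgbm as lgb

        bst = lgb.train(
            {"objective": "regression"},
            dtrain,
            valid_sets=[dvalid],
            callbacks=[lgb.log_evaluation(period=10)],
        )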
    """
    return _LogEvaluationCallback(period=period, show_stdv=show_stdv)


class _RecordEvaluationCallback:
    """Internal record evaluation callable class."""

    def __init__(self, eval_result: _EvalResultDict) -> None:
        self.order = 20
        self.before_iteration = False

        if not isinstance(eval_result, dict):
            raise TypeError("eval_result should be a dictionary")
        self.eval_result = eval_result

    def _init(self, env: CallbackEnv) -> None:
        if env.evaluation_result_list is None:
            raise RuntimeError(
                "record_evaluation() callback enabled but no evaluation results found. This is a probably bug in LightGBM. "
                "Please report it at https://github.com/microsoft/LightGBM/issues"
            )
        self.eval_result.clear()
        for item in env.evaluation_result_list:
            if len(item) == 4:  # regular train
                data_name, eval_name = item[:2]
            else:  # cv
                data_name, eval_name = item[1].split()
            self.eval_result.setdefault(data_name, OrderedDict())
            if len(item) == 4:
                self.eval_result[data_name].setdefault(eval_name, [])
            else:
                self.eval_result[data_name].setdefault(f"{eval_name}-mean", [])
                self.eval_result[data_name].setdefault(f"{eval_name}-stdv", [])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        if env.evaluation_result_list is None:
            raise RuntimeError(
                "record_evaluation() callback enabled but no evaluation results found. This is a probably bug in LightGBM. "
                "Please report it at https://github.com/microsoft/LightGBM/issues"
            )
        for item in env.evaluation_result_list:
            if len(item) == 4:
                data_name, eval_name, result = item[:3]
                self.eval_result[data_name][eval_name].append(result)
            else:
                data_name, eval_name = item[1].split()
                res_mean = item[2]
                res_stdv = item[4]  # type: ignore[misc]
                self.eval_result[data_name][f"{eval_name}-mean"].append(res_mean)
                self.eval_result[data_name][f"{eval_name}-stdv"].append(res_stdv)


def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> _RecordEvaluationCallback:
    """Create a callback that records the evaluation history into ``eval_result``.

    Parameters
    ----------
    eval_result : dict
        Dictionary used to store all evaluation results of all validation sets.
        This should be initialized outside of your call to ``record_evaluation()`` and should be empty.
        Any initial contents of the dictionary will be deleted.

        .. rubric:: Example

        With two validation sets named 'eval' and 'train', and one evaluation metric named 'logloss',
        this dictionary after finishing a model training process will have the following structure:

        .. code-block::

            {
             'train':
                 {
                  'logloss': [0.48253, 0.35953, ...]
                 },
             'eval':
                 {
                  'logloss': [0.480385, 0.357756, ...]
                 }
            }

    Returns
    -------
    callback : _RecordEvaluationCallback
        The callback that records the evaluation history into the passed dictionary.
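
    .. rubric:: Example

    A minimal usage sketch (``dtrain`` and ``dvalid`` are hypothetical
    ``Dataset`` objects):

    .. code-block:: python

        import lightgbm as lgb

        history = {}
        bst = lgb.train(
            {"objective": "regression"},
            dtrain,
            valid_sets=[dvalid],
            valid_names=["eval"],
            callbacks=[lgb.record_evaluation(history)],
        )
        # history now maps dataset name -> metric name -> list of per-iteration values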
    """
    return _RecordEvaluationCallback(eval_result=eval_result)


class _ResetParameterCallback:
    """Internal reset parameter callable class."""

    def __init__(self, **kwargs: Union[list, Callable]) -> None:
        self.order = 10
        self.before_iteration = True

        self.kwargs = kwargs

    def __call__(self, env: CallbackEnv) -> None:
        new_parameters = {}
        for key, value in self.kwargs.items():
            if isinstance(value, list):
                if len(value) != env.end_iteration - env.begin_iteration:
                    raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.")
                new_param = value[env.iteration - env.begin_iteration]
            elif callable(value):
                new_param = value(env.iteration - env.begin_iteration)
            else:
                raise ValueError(
                    "Only list and callable values are supported "
                    "as a mapping from boosting round index to new parameter value."
                )
            if new_param != env.params.get(key, None):
                new_parameters[key] = new_param
        if new_parameters:
            if isinstance(env.model, Booster):
                env.model.reset_parameter(new_parameters)
            else:
                # CVBooster holds a list of Booster objects, each needs to be updated
                for booster in env.model.boosters:
                    booster.reset_parameter(new_parameters)
            env.params.update(new_parameters)


def reset_parameter(**kwargs: Union[list, Callable]) -> _ResetParameterCallback:
    """Create a callback that resets the parameter after the first iteration.

    .. note::

        The initial parameter still takes effect on the first iteration.

    Parameters
    ----------
    **kwargs : value should be list or callable
        List of parameters for each boosting round
        or a callable that calculates the parameter in terms of
        the current round number (e.g. yields learning rate decay).
        If list lst, parameter = lst[current_round].
        If callable func, parameter = func(current_round).

    Returns
    -------
    callback : _ResetParameterCallback
        The callback that resets the parameter after the first iteration.
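
    .. rubric:: Example

    A minimal sketch: decaying ``learning_rate`` over 100 boosting rounds,
    either from a precomputed list or via a callable.

    .. code-block:: python

        import lightgbm as lgb

        # list form: one value per boosting round (length must equal num_boost_round)
        cb = lgb.reset_parameter(learning_rate=[0.1 * (0.99**i) for i in range(100)])

        # callable form: value computed from the current round index
        cb = lgb.reset_parameter(learning_rate=lambda i: 0.1 * (0.99**i))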
    """
    return _ResetParameterCallback(**kwargs)


class _EarlyStoppingCallback:
    """Internal early stopping callable class."""

    def __init__(
        self,
        stopping_rounds: int,
        first_metric_only: bool = False,
        verbose: bool = True,
        min_delta: Union[float, List[float]] = 0.0,
    ) -> None:
        self.enabled = _should_enable_early_stopping(stopping_rounds)

        self.order = 30
        self.before_iteration = False

        self.stopping_rounds = stopping_rounds
        self.first_metric_only = first_metric_only
        self.verbose = verbose
        self.min_delta = min_delta

        self._reset_storages()

    def _reset_storages(self) -> None:
        self.best_score: List[float] = []
        self.best_iter: List[int] = []
        self.best_score_list: List[_ListOfEvalResultTuples] = []
        self.cmp_op: List[Callable[[float, float], bool]] = []
        self.first_metric = ""

    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score > best_score + delta

    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
        return curr_score < best_score - delta

    def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
        """Check, by name, if a given Dataset is the training data."""
        # for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
        # and those metrics are considered for early stopping
        if ds_name == "cv_agg" and eval_name == "train":
            return True

        # for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
        if isinstance(env.model, Booster) and ds_name == env.model._train_data_name:
            return True

        return False

    def _init(self, env: CallbackEnv) -> None:
        if env.evaluation_result_list is None or env.evaluation_result_list == []:
            raise ValueError("For early stopping, at least one dataset and eval metric is required for evaluation")

        is_dart = any(env.params.get(alias, "") == "dart" for alias in _ConfigAliases.get("boosting"))
        if is_dart:
            self.enabled = False
            _log_warning("Early stopping is not available in dart mode")
            return

        # validation sets are guaranteed to not be identical to the training data in cv()
        if isinstance(env.model, Booster):
            only_train_set = len(env.evaluation_result_list) == 1 and self._is_train_set(
                ds_name=env.evaluation_result_list[0][0],
                eval_name=env.evaluation_result_list[0][1].split(" ")[0],
                env=env,
            )
            if only_train_set:
                self.enabled = False
                _log_warning("Only training set found, disabling early stopping.")
                return

        if self.verbose:
            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")

        self._reset_storages()

        n_metrics = len({m[1] for m in env.evaluation_result_list})
        n_datasets = len(env.evaluation_result_list) // n_metrics
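        # env.evaluation_result_list holds one entry per (validation set, metric)
        # pair, grouped by validation set; the deltas built below must follow the
        # same flattened ordering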
        if isinstance(self.min_delta, list):
            if not all(t >= 0 for t in self.min_delta):
                raise ValueError("Values for early stopping min_delta must be non-negative.")
            if len(self.min_delta) == 0:
                if self.verbose:
                    _log_info("Disabling min_delta for early stopping.")
                deltas = [0.0] * n_datasets * n_metrics
            elif len(self.min_delta) == 1:
                if self.verbose:
                    _log_info(f"Using {self.min_delta[0]} as min_delta for all metrics.")
                deltas = self.min_delta * n_datasets * n_metrics
            else:
                if len(self.min_delta) != n_metrics:
                    raise ValueError("Must provide a single value for min_delta or as many as metrics.")
                if self.first_metric_only and self.verbose:
                    _log_info(f"Using only {self.min_delta[0]} as early stopping min_delta.")
                deltas = self.min_delta * n_datasets
        else:
            if self.min_delta < 0:
                raise ValueError("Early stopping min_delta must be non-negative.")
            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
                _log_info(f"Using {self.min_delta} as min_delta for all metrics.")
            deltas = [self.min_delta] * n_datasets * n_metrics

        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
        self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
            self.best_iter.append(0)
            if eval_ret[3]:  # greater is better
                self.best_score.append(float("-inf"))
                self.cmp_op.append(partial(self._gt_delta, delta=delta))
            else:
                self.best_score.append(float("inf"))
                self.cmp_op.append(partial(self._lt_delta, delta=delta))

    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
        if env.iteration == env.end_iteration - 1:
            if self.verbose:
                best_score_str = "\t".join([_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]])
                _log_info(
                    f"Did not meet early stopping. Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}"
                )
                if self.first_metric_only:
                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])

    def __call__(self, env: CallbackEnv) -> None:
        if env.iteration == env.begin_iteration:
            self._init(env)
        if not self.enabled:
            return
        if env.evaluation_result_list is None:
            raise RuntimeError(
                "early_stopping() callback enabled but no evaluation results found. This is a probably bug in LightGBM. "
                "Please report it at https://github.com/microsoft/LightGBM/issues"
            )
        # self.best_score_list is initialized to an empty list
        first_time_updating_best_score_list = self.best_score_list == []
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2]
            if first_time_updating_best_score_list or self.cmp_op[i](score, self.best_score[i]):
                self.best_score[i] = score
                self.best_iter[i] = env.iteration
                if first_time_updating_best_score_list:
                    self.best_score_list.append(env.evaluation_result_list)
                else:
                    self.best_score_list[i] = env.evaluation_result_list
            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                continue  # use only the first metric for early stopping
            if self._is_train_set(
                ds_name=env.evaluation_result_list[i][0],
                eval_name=eval_name_splitted[0],
                env=env,
            ):
                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                if self.verbose:
                    eval_result_str = "\t".join(
                        [_format_eval_result(x, show_stdv=True) for x in self.best_score_list[i]]
                    )
                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
                    if self.first_metric_only:
                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
            self._final_iteration_check(env, eval_name_splitted, i)


def _should_enable_early_stopping(stopping_rounds: Any) -> bool:
    """Check if early stopping should be activated.

    This function will evaluate to True if the early stopping callback should be
    activated (i.e. stopping_rounds > 0). It also raises an informative error if the
    type is not int.
    """
    if not isinstance(stopping_rounds, int):
        raise TypeError(f"early_stopping_round should be an integer. Got '{type(stopping_rounds).__name__}'")
    return stopping_rounds > 0


def early_stopping(
    stopping_rounds: int,
    first_metric_only: bool = False,
    verbose: bool = True,
    min_delta: Union[float, List[float]] = 0.0,
) -> _EarlyStoppingCallback:
    """Create a callback that activates early stopping.

    Activates early stopping.
    The model will train until the validation score doesn't improve by at least ``min_delta``.
    Validation score needs to improve at least every ``stopping_rounds`` round(s)
    to continue training.
    Requires at least one validation dataset and one metric.
    If there's more than one, all of them will be checked, but the training data is ignored anyway.
    To check only the first metric, set ``first_metric_only`` to True.
    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.

    Parameters
    ----------
    stopping_rounds : int
        The number of rounds with no improvement after which training will be stopped.
    first_metric_only : bool, optional (default=False)
        Whether to use only the first metric for early stopping.
    verbose : bool, optional (default=True)
        Whether to log message with early stopping information.
        By default, standard output is used.
        Use ``register_logger()`` function to register a custom logger.
    min_delta : float or list of float, optional (default=0.0)
        Minimum improvement in score to keep training.
        If float, this single value is used for all metrics.
        If list, its length should match the total number of metrics.

        .. versionadded:: 4.0.0

    Returns
    -------
    callback : _EarlyStoppingCallback
        The callback that activates early stopping.
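
    .. rubric:: Example

    A minimal usage sketch (``dtrain`` and ``dvalid`` are hypothetical
    ``Dataset`` objects):

    .. code-block:: python

        import lightgbm as lgb

        bst = lgb.train(
            {"objective": "regression", "metric": "l2"},
            dtrain,
            num_boost_round=1000,
            valid_sets=[dvalid],
            callbacks=[lgb.early_stopping(stopping_rounds=50, min_delta=1e-4)],
        )
        # bst.best_iteration holds the (1-based) best iteration found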
    """
    return _EarlyStoppingCallback(
        stopping_rounds=stopping_rounds,
        first_metric_only=first_metric_only,
        verbose=verbose,
        min_delta=min_delta,
    )