"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "cc11525d261c29a0d965cbcbbda0aa884d35b46d"
Unverified commit f77e0adf, authored by Antoni Baum and committed by GitHub

[python] make `early_stopping` callback pickleable (#5012)

* Turn `early_stopping` into a Callable class

* Fix

* Lint

* Remove print

* Fix order

* Revert "Lint"

This reverts commit 7ca8b557572446888cf793c0082d9a7efd1e29a7.

* Apply suggestion from code review

* Nit

* Lint

* Move callable class outside the func for pickling

* Move _pickle and _unpickle to tests utils

* Add early stopping callback picklability test

* Nit

* Fix

* Lint

* Improve type hint

* Lint

* Lint

* Add cloudpickle to test_windows

* Update tests/python_package_test/test_engine.py

* Fix

* Apply suggestions from code review
parent eb686a76
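For context: Python's built-in pickle serializes functions by reference to their module-level name, so the callback returned by the old closure-based `early_stopping` factory (a function defined inside another function) could not be pickled, e.g. when shipping a training setup to Dask workers or saving it with joblib. A module-level class whose instances keep the same state in attributes avoids the problem. The sketch below is illustrative only and is not code from this diff; `make_closure_callback` and `CallbackClass` are hypothetical stand-ins for the old factory and the new `_EarlyStoppingCallback`.

# Illustrative sketch: why a closure-based callback fails to pickle while a
# module-level callable class round-trips cleanly.
import pickle


def make_closure_callback(stopping_rounds):      # hypothetical stand-in for the old factory
    def _callback(env):
        _ = stopping_rounds                      # state captured via the closure
    return _callback


class CallbackClass:                             # hypothetical stand-in for _EarlyStoppingCallback
    def __init__(self, stopping_rounds):
        self.stopping_rounds = stopping_rounds   # state stored on the instance

    def __call__(self, env):
        pass


try:
    pickle.dumps(make_closure_callback(5))       # local functions cannot be pickled by reference
except (AttributeError, pickle.PicklingError) as err:
    print(f"closure-based callback fails: {err}")

restored = pickle.loads(pickle.dumps(CallbackClass(5)))  # succeeds
print(restored.stopping_rounds)                  # -> 5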
@@ -50,7 +50,7 @@ if ($env:TASK -eq "swig") {
     Exit 0
 }
 
-conda install -q -y -n $env:CONDA_ENV joblib matplotlib numpy pandas psutil pytest scikit-learn scipy ; Check-Output $?
+conda install -q -y -n $env:CONDA_ENV cloudpickle joblib matplotlib numpy pandas psutil pytest scikit-learn scipy ; Check-Output $?
 # python-graphviz has to be installed separately to prevent conda from downgrading to pypy
 conda install -q -y -n $env:CONDA_ENV libxml2 python-graphviz ; Check-Output $?
@@ -12,14 +12,6 @@ _EvalResultTuple = Union[
 ]
 
 
-def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
-    return curr_score > best_score + delta
-
-
-def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
-    return curr_score < best_score - delta
-
-
 class EarlyStopException(Exception):
     """Exception of early stopping."""
@@ -199,156 +191,165 @@ def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
     return _callback
 
 
-def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> Callable:
-    """Create a callback that activates early stopping.
-
-    Activates early stopping.
-    The model will train until the validation score doesn't improve by at least ``min_delta``.
-    Validation score needs to improve at least every ``stopping_rounds`` round(s)
-    to continue training.
-    Requires at least one validation data and one metric.
-    If there's more than one, will check all of them. But the training data is ignored anyway.
-    To check only the first metric set ``first_metric_only`` to True.
-    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.
-
-    Parameters
-    ----------
-    stopping_rounds : int
-        The possible number of rounds without the trend occurrence.
-    first_metric_only : bool, optional (default=False)
-        Whether to use only the first metric for early stopping.
-    verbose : bool, optional (default=True)
-        Whether to log message with early stopping information.
-        By default, standard output resource is used.
-        Use ``register_logger()`` function to register a custom logger.
-    min_delta : float or list of float, optional (default=0.0)
-        Minimum improvement in score to keep training.
-        If float, this single value is used for all metrics.
-        If list, its length should match the total number of metrics.
-
-    Returns
-    -------
-    callback : callable
-        The callback that activates early stopping.
-    """
-    best_score = []
-    best_iter = []
-    best_score_list: list = []
-    cmp_op = []
-    enabled = True
-    first_metric = ''
-
-    def _init(env: CallbackEnv) -> None:
-        nonlocal best_score
-        nonlocal best_iter
-        nonlocal best_score_list
-        nonlocal cmp_op
-        nonlocal enabled
-        nonlocal first_metric
-        enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
-                          in _ConfigAliases.get("boosting"))
-        if not enabled:
-            _log_warning('Early stopping is not available in dart mode')
-            return
-        if not env.evaluation_result_list:
-            raise ValueError('For early stopping, '
-                             'at least one dataset and eval metric is required for evaluation')
-
-        if stopping_rounds <= 0:
-            raise ValueError("stopping_rounds should be greater than zero.")
-
-        if verbose:
-            _log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")
-
-        # reset storages
-        best_score = []
-        best_iter = []
-        best_score_list = []
-        cmp_op = []
-        first_metric = ''
-
-        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
-        n_datasets = len(env.evaluation_result_list) // n_metrics
-        if isinstance(min_delta, list):
-            if not all(t >= 0 for t in min_delta):
-                raise ValueError('Values for early stopping min_delta must be non-negative.')
-            if len(min_delta) == 0:
-                if verbose:
-                    _log_info('Disabling min_delta for early stopping.')
-                deltas = [0.0] * n_datasets * n_metrics
-            elif len(min_delta) == 1:
-                if verbose:
-                    _log_info(f'Using {min_delta[0]} as min_delta for all metrics.')
-                deltas = min_delta * n_datasets * n_metrics
-            else:
-                if len(min_delta) != n_metrics:
-                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
-                if first_metric_only and verbose:
-                    _log_info(f'Using only {min_delta[0]} as early stopping min_delta.')
-                deltas = min_delta * n_datasets
-        else:
-            if min_delta < 0:
-                raise ValueError('Early stopping min_delta must be non-negative.')
-            if min_delta > 0 and n_metrics > 1 and not first_metric_only and verbose:
-                _log_info(f'Using {min_delta} as min_delta for all metrics.')
-            deltas = [min_delta] * n_datasets * n_metrics
-
-        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
-        first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
-        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
-            best_iter.append(0)
-            best_score_list.append(None)
-            if eval_ret[3]:  # greater is better
-                best_score.append(float('-inf'))
-                cmp_op.append(partial(_gt_delta, delta=delta))
-            else:
-                best_score.append(float('inf'))
-                cmp_op.append(partial(_lt_delta, delta=delta))
-
-    def _final_iteration_check(env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
-        nonlocal best_iter
-        nonlocal best_score_list
-        if env.iteration == env.end_iteration - 1:
-            if verbose:
-                best_score_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
-                _log_info('Did not meet early stopping. '
-                          f'Best iteration is:\n[{best_iter[i] + 1}]\t{best_score_str}')
-                if first_metric_only:
-                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-            raise EarlyStopException(best_iter[i], best_score_list[i])
-
-    def _callback(env: CallbackEnv) -> None:
-        nonlocal best_score
-        nonlocal best_iter
-        nonlocal best_score_list
-        nonlocal cmp_op
-        nonlocal enabled
-        nonlocal first_metric
-        if env.iteration == env.begin_iteration:
-            _init(env)
-        if not enabled:
-            return
-        for i in range(len(env.evaluation_result_list)):
-            score = env.evaluation_result_list[i][2]
-            if best_score_list[i] is None or cmp_op[i](score, best_score[i]):
-                best_score[i] = score
-                best_iter[i] = env.iteration
-                best_score_list[i] = env.evaluation_result_list
-            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
-            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
-            if first_metric_only and first_metric != eval_name_splitted[-1]:
-                continue  # use only the first metric for early stopping
-            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
-                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
-                _final_iteration_check(env, eval_name_splitted, i)
-                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
-            elif env.iteration - best_iter[i] >= stopping_rounds:
-                if verbose:
-                    eval_result_str = '\t'.join([_format_eval_result(x) for x in best_score_list[i]])
-                    _log_info(f"Early stopping, best iteration is:\n[{best_iter[i] + 1}]\t{eval_result_str}")
-                    if first_metric_only:
-                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
-                raise EarlyStopException(best_iter[i], best_score_list[i])
-            _final_iteration_check(env, eval_name_splitted, i)
-    _callback.order = 30  # type: ignore
-    return _callback
+class _EarlyStoppingCallback:
+    """Internal early stopping callable class."""
+
+    def __init__(
+        self,
+        stopping_rounds: int,
+        first_metric_only: bool = False,
+        verbose: bool = True,
+        min_delta: Union[float, List[float]] = 0.0
+    ) -> None:
+        self.order = 30
+        self.before_iteration = False
+
+        self.stopping_rounds = stopping_rounds
+        self.first_metric_only = first_metric_only
+        self.verbose = verbose
+        self.min_delta = min_delta
+
+        self.enabled = True
+        self._reset_storages()
+
+    def _reset_storages(self) -> None:
+        self.best_score = []
+        self.best_iter = []
+        self.best_score_list = []
+        self.cmp_op = []
+        self.first_metric = ''
+
+    def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
+        return curr_score > best_score + delta
+
+    def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
+        return curr_score < best_score - delta
+
+    def _init(self, env: CallbackEnv) -> None:
+        self.enabled = not any(env.params.get(boost_alias, "") == 'dart' for boost_alias
+                               in _ConfigAliases.get("boosting"))
+        if not self.enabled:
+            _log_warning('Early stopping is not available in dart mode')
+            return
+        if not env.evaluation_result_list:
+            raise ValueError('For early stopping, '
+                             'at least one dataset and eval metric is required for evaluation')
+
+        if self.stopping_rounds <= 0:
+            raise ValueError("stopping_rounds should be greater than zero.")
+
+        if self.verbose:
+            _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
+
+        self._reset_storages()
+
+        n_metrics = len(set(m[1] for m in env.evaluation_result_list))
+        n_datasets = len(env.evaluation_result_list) // n_metrics
+        if isinstance(self.min_delta, list):
+            if not all(t >= 0 for t in self.min_delta):
+                raise ValueError('Values for early stopping min_delta must be non-negative.')
+            if len(self.min_delta) == 0:
+                if self.verbose:
+                    _log_info('Disabling min_delta for early stopping.')
+                deltas = [0.0] * n_datasets * n_metrics
+            elif len(self.min_delta) == 1:
+                if self.verbose:
+                    _log_info(f'Using {self.min_delta[0]} as min_delta for all metrics.')
+                deltas = self.min_delta * n_datasets * n_metrics
+            else:
+                if len(self.min_delta) != n_metrics:
+                    raise ValueError('Must provide a single value for min_delta or as many as metrics.')
+                if self.first_metric_only and self.verbose:
+                    _log_info(f'Using only {self.min_delta[0]} as early stopping min_delta.')
+                deltas = self.min_delta * n_datasets
+        else:
+            if self.min_delta < 0:
+                raise ValueError('Early stopping min_delta must be non-negative.')
+            if self.min_delta > 0 and n_metrics > 1 and not self.first_metric_only and self.verbose:
+                _log_info(f'Using {self.min_delta} as min_delta for all metrics.')
+            deltas = [self.min_delta] * n_datasets * n_metrics
+
+        # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
+        self.first_metric = env.evaluation_result_list[0][1].split(" ")[-1]
+        for eval_ret, delta in zip(env.evaluation_result_list, deltas):
+            self.best_iter.append(0)
+            self.best_score_list.append(None)
+            if eval_ret[3]:  # greater is better
+                self.best_score.append(float('-inf'))
+                self.cmp_op.append(partial(self._gt_delta, delta=delta))
+            else:
+                self.best_score.append(float('inf'))
+                self.cmp_op.append(partial(self._lt_delta, delta=delta))
+
+    def _final_iteration_check(self, env: CallbackEnv, eval_name_splitted: List[str], i: int) -> None:
+        if env.iteration == env.end_iteration - 1:
+            if self.verbose:
+                best_score_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
+                _log_info('Did not meet early stopping. '
+                          f'Best iteration is:\n[{self.best_iter[i] + 1}]\t{best_score_str}')
+                if self.first_metric_only:
+                    _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
+            raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
+
+    def __call__(self, env: CallbackEnv) -> None:
+        if env.iteration == env.begin_iteration:
+            self._init(env)
+        if not self.enabled:
+            return
+        for i in range(len(env.evaluation_result_list)):
+            score = env.evaluation_result_list[i][2]
+            if self.best_score_list[i] is None or self.cmp_op[i](score, self.best_score[i]):
+                self.best_score[i] = score
+                self.best_iter[i] = env.iteration
+                self.best_score_list[i] = env.evaluation_result_list
+            # split is needed for "<dataset type> <metric>" case (e.g. "train l1")
+            eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
+            if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
+                continue  # use only the first metric for early stopping
+            if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
+                 or env.evaluation_result_list[i][0] == env.model._train_data_name)):
+                self._final_iteration_check(env, eval_name_splitted, i)
+                continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
+            elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
+                if self.verbose:
+                    eval_result_str = '\t'.join([_format_eval_result(x) for x in self.best_score_list[i]])
+                    _log_info(f"Early stopping, best iteration is:\n[{self.best_iter[i] + 1}]\t{eval_result_str}")
+                    if self.first_metric_only:
+                        _log_info(f"Evaluated only: {eval_name_splitted[-1]}")
+                raise EarlyStopException(self.best_iter[i], self.best_score_list[i])
+            self._final_iteration_check(env, eval_name_splitted, i)
+
+
+def early_stopping(stopping_rounds: int, first_metric_only: bool = False, verbose: bool = True, min_delta: Union[float, List[float]] = 0.0) -> _EarlyStoppingCallback:
+    """Create a callback that activates early stopping.
+
+    Activates early stopping.
+    The model will train until the validation score doesn't improve by at least ``min_delta``.
+    Validation score needs to improve at least every ``stopping_rounds`` round(s)
+    to continue training.
+    Requires at least one validation data and one metric.
+    If there's more than one, will check all of them. But the training data is ignored anyway.
+    To check only the first metric set ``first_metric_only`` to True.
+    The index of iteration that has the best performance will be saved in the ``best_iteration`` attribute of a model.
+
+    Parameters
+    ----------
+    stopping_rounds : int
+        The possible number of rounds without the trend occurrence.
+    first_metric_only : bool, optional (default=False)
+        Whether to use only the first metric for early stopping.
+    verbose : bool, optional (default=True)
+        Whether to log message with early stopping information.
+        By default, standard output resource is used.
+        Use ``register_logger()`` function to register a custom logger.
+    min_delta : float or list of float, optional (default=0.0)
+        Minimum improvement in score to keep training.
+        If float, this single value is used for all metrics.
+        If list, its length should match the total number of metrics.
+
+    Returns
+    -------
+    callback : _EarlyStoppingCallback
+        The callback that activates early stopping.
+    """
+    return _EarlyStoppingCallback(stopping_rounds=stopping_rounds, first_metric_only=first_metric_only, verbose=verbose, min_delta=min_delta)
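From a user's point of view the public API is unchanged: `early_stopping(...)` is still passed through `callbacks=`, it just now returns an `_EarlyStoppingCallback` instance and can therefore be pickled together with the rest of a training setup. A rough usage sketch follows; the synthetic data and parameter values are illustrative, not taken from the diff.

# Rough usage sketch: the callback behaves as before and now survives a pickle round trip.
import pickle

import numpy as np
import lightgbm as lgb

X = np.random.rand(500, 10)
y = np.random.rand(500)
train_set = lgb.Dataset(X[:400], label=y[:400])
valid_set = lgb.Dataset(X[400:], label=y[400:], reference=train_set)

callback = lgb.early_stopping(stopping_rounds=10, min_delta=1e-4)
callback = pickle.loads(pickle.dumps(callback))  # round-trip now works

booster = lgb.train(
    {"objective": "regression", "metric": "l2", "verbosity": -1},
    train_set,
    num_boost_round=100,
    valid_sets=[valid_set],
    callbacks=[callback],
)
print(booster.best_iteration)  # best iteration found by early stopping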
# coding: utf-8
import pytest

import lightgbm as lgb

from .utils import pickle_obj, unpickle_obj


@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_early_stopping_callback_is_picklable(serializer, tmp_path):
    callback = lgb.early_stopping(stopping_rounds=5)
    tmp_file = tmp_path / "early_stopping.pkl"
    pickle_obj(
        obj=callback,
        filepath=tmp_file,
        serializer=serializer
    )
    callback_from_disk = unpickle_obj(
        filepath=tmp_file,
        serializer=serializer
    )
    assert callback.stopping_rounds == callback_from_disk.stopping_rounds
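For a quick interactive check outside the test suite, the same round trip can be done in memory; this snippet is not part of the diff and assumes cloudpickle is installed:

import cloudpickle

import lightgbm as lgb

callback = lgb.early_stopping(stopping_rounds=5)
clone = cloudpickle.loads(cloudpickle.dumps(callback))  # in-memory round trip
assert clone.stopping_rounds == callback.stopping_rounds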
@@ -2,7 +2,6 @@
 """Tests for lightgbm.dask module"""
 
 import inspect
-import pickle
 import random
 import socket
 from itertools import groupby
@@ -24,10 +23,8 @@ if machine() != 'x86_64':
 if not lgb.compat.DASK_INSTALLED:
     pytest.skip('Dask is not installed', allow_module_level=True)
 
-import cloudpickle
 import dask.array as da
 import dask.dataframe as dd
-import joblib
 import numpy as np
 import pandas as pd
 import sklearn.utils.estimator_checks as sklearn_checks
@@ -37,7 +34,7 @@ from scipy.sparse import csc_matrix, csr_matrix
 from scipy.stats import spearmanr
 from sklearn.datasets import make_blobs, make_regression
 
-from .utils import make_ranking
+from .utils import make_ranking, pickle_obj, unpickle_obj
 
 tasks = ['binary-classification', 'multiclass-classification', 'regression', 'ranking']
 distributed_training_algorithms = ['data', 'voting']
@@ -234,32 +231,6 @@ def _constant_metric(y_true, y_pred):
     return metric_name, value, is_higher_better
 
 
-def _pickle(obj, filepath, serializer):
-    if serializer == 'pickle':
-        with open(filepath, 'wb') as f:
-            pickle.dump(obj, f)
-    elif serializer == 'joblib':
-        joblib.dump(obj, filepath)
-    elif serializer == 'cloudpickle':
-        with open(filepath, 'wb') as f:
-            cloudpickle.dump(obj, f)
-    else:
-        raise ValueError(f'Unrecognized serializer type: {serializer}')
-
-
-def _unpickle(filepath, serializer):
-    if serializer == 'pickle':
-        with open(filepath, 'rb') as f:
-            return pickle.load(f)
-    elif serializer == 'joblib':
-        return joblib.load(filepath)
-    elif serializer == 'cloudpickle':
-        with open(filepath, 'rb') as f:
-            return cloudpickle.load(f)
-    else:
-        raise ValueError(f'Unrecognized serializer type: {serializer}')
-
-
 def _objective_least_squares(y_true, y_pred):
     grad = y_pred - y_true
     hess = np.ones(len(y_true))
@@ -1341,23 +1312,23 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
     assert getattr(local_model, "client", None) is None
 
     tmp_file = tmp_path / "model-1.pkl"
-    _pickle(
+    pickle_obj(
         obj=dask_model,
         filepath=tmp_file,
         serializer=serializer
     )
-    model_from_disk = _unpickle(
+    model_from_disk = unpickle_obj(
         filepath=tmp_file,
         serializer=serializer
    )
 
     local_tmp_file = tmp_path / "local-model-1.pkl"
-    _pickle(
+    pickle_obj(
         obj=local_model,
         filepath=local_tmp_file,
         serializer=serializer
     )
-    local_model_from_disk = _unpickle(
+    local_model_from_disk = unpickle_obj(
         filepath=local_tmp_file,
         serializer=serializer
     )
@@ -1397,23 +1368,23 @@ def test_model_and_local_version_are_picklable_whether_or_not_client_set_explici
     local_model.client_
 
     tmp_file2 = tmp_path / "model-2.pkl"
-    _pickle(
+    pickle_obj(
         obj=dask_model,
         filepath=tmp_file2,
         serializer=serializer
     )
-    fitted_model_from_disk = _unpickle(
+    fitted_model_from_disk = unpickle_obj(
         filepath=tmp_file2,
         serializer=serializer
     )
 
     local_tmp_file2 = tmp_path / "local-model-2.pkl"
-    _pickle(
+    pickle_obj(
         obj=local_model,
         filepath=local_tmp_file2,
         serializer=serializer
     )
-    local_fitted_model_from_disk = _unpickle(
+    local_fitted_model_from_disk = unpickle_obj(
         filepath=local_tmp_file2,
         serializer=serializer
     )
 # coding: utf-8
+import pickle
 from functools import lru_cache
 
+import cloudpickle
+import joblib
 import numpy as np
 import sklearn.datasets
 from sklearn.utils import check_random_state
@@ -131,3 +134,29 @@ def sklearn_multiclass_custom_objective(y_true, y_pred):
     factor = num_class / (num_class - 1)
     hess = factor * prob * (1 - prob)
     return grad, hess
+
+
+def pickle_obj(obj, filepath, serializer):
+    if serializer == 'pickle':
+        with open(filepath, 'wb') as f:
+            pickle.dump(obj, f)
+    elif serializer == 'joblib':
+        joblib.dump(obj, filepath)
+    elif serializer == 'cloudpickle':
+        with open(filepath, 'wb') as f:
+            cloudpickle.dump(obj, f)
+    else:
+        raise ValueError(f'Unrecognized serializer type: {serializer}')
+
+
+def unpickle_obj(filepath, serializer):
+    if serializer == 'pickle':
+        with open(filepath, 'rb') as f:
+            return pickle.load(f)
+    elif serializer == 'joblib':
+        return joblib.load(filepath)
+    elif serializer == 'cloudpickle':
+        with open(filepath, 'rb') as f:
+            return cloudpickle.load(f)
+    else:
+        raise ValueError(f'Unrecognized serializer type: {serializer}')
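A short usage sketch of these shared helpers follows; it is illustrative only, assumes the functions above are in scope (or imported from the test utils module) and that joblib and cloudpickle are installed, and the object and path are arbitrary examples.

# Example round trip with the shared test helpers; any picklable object works.
from pathlib import Path

obj = {"stopping_rounds": 5}
path = Path("obj.pkl")

for serializer in ("pickle", "joblib", "cloudpickle"):
    pickle_obj(obj=obj, filepath=path, serializer=serializer)
    assert unpickle_obj(filepath=path, serializer=serializer) == obj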