"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "5dfe7168d42898b66da3513eb8cab68ef2b23eeb"
Unverified Commit 8b33e776 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] make `log_evaluation` callback pickleable (#5101)

* make `log_evaluation` callback pickleable

* make callback tests stricter
parent 417c732c
...@@ -54,7 +54,23 @@ def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str: ...@@ -54,7 +54,23 @@ def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
raise ValueError("Wrong metric value") raise ValueError("Wrong metric value")
def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable: class _LogEvaluationCallback:
"""Internal log evaluation callable class."""
def __init__(self, period: int = 1, show_stdv: bool = True) -> None:
self.order = 10
self.before_iteration = False
self.period = period
self.show_stdv = show_stdv
def __call__(self, env: CallbackEnv) -> None:
if self.period > 0 and env.evaluation_result_list and (env.iteration + 1) % self.period == 0:
result = '\t'.join([_format_eval_result(x, self.show_stdv) for x in env.evaluation_result_list])
_log_info(f'[{env.iteration + 1}]\t{result}')
def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCallback:
"""Create a callback that logs the evaluation results. """Create a callback that logs the evaluation results.
By default, standard output resource is used. By default, standard output resource is used.
...@@ -74,15 +90,10 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable: ...@@ -74,15 +90,10 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> Callable:
Returns Returns
------- -------
callback : callable callback : _LogEvaluationCallback
The callback that logs the evaluation results every ``period`` boosting iteration(s). The callback that logs the evaluation results every ``period`` boosting iteration(s).
""" """
def _callback(env: CallbackEnv) -> None: return _LogEvaluationCallback(period=period, show_stdv=show_stdv)
if period > 0 and env.evaluation_result_list and (env.iteration + 1) % period == 0:
result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
_log_info(f'[{env.iteration + 1}]\t{result}')
_callback.order = 10 # type: ignore
return _callback
def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable: def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
......
...@@ -8,7 +8,8 @@ from .utils import pickle_obj, unpickle_obj ...@@ -8,7 +8,8 @@ from .utils import pickle_obj, unpickle_obj
@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"]) @pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_early_stopping_callback_is_picklable(serializer, tmp_path): def test_early_stopping_callback_is_picklable(serializer, tmp_path):
callback = lgb.early_stopping(stopping_rounds=5) rounds = 5
callback = lgb.early_stopping(stopping_rounds=rounds)
tmp_file = tmp_path / "early_stopping.pkl" tmp_file = tmp_path / "early_stopping.pkl"
pickle_obj( pickle_obj(
obj=callback, obj=callback,
...@@ -20,3 +21,22 @@ def test_early_stopping_callback_is_picklable(serializer, tmp_path): ...@@ -20,3 +21,22 @@ def test_early_stopping_callback_is_picklable(serializer, tmp_path):
serializer=serializer serializer=serializer
) )
assert callback.stopping_rounds == callback_from_disk.stopping_rounds assert callback.stopping_rounds == callback_from_disk.stopping_rounds
assert callback.stopping_rounds == rounds
@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_log_evaluation_callback_is_picklable(serializer, tmp_path):
periods = 42
callback = lgb.log_evaluation(period=periods)
tmp_file = tmp_path / "log_evaluation.pkl"
pickle_obj(
obj=callback,
filepath=tmp_file,
serializer=serializer
)
callback_from_disk = unpickle_obj(
filepath=tmp_file,
serializer=serializer
)
assert callback.period == callback_from_disk.period
assert callback.period == periods
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment