Unverified Commit 60244e4a authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] make `record_evaluation` callback pickleable (#5107)

* make `log_evaluation` callback pickleable

* make callback tests stricter

* make `record_evaluation` callback picklable
parent 8b33e776
...@@ -96,6 +96,45 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCal ...@@ -96,6 +96,45 @@ def log_evaluation(period: int = 1, show_stdv: bool = True) -> _LogEvaluationCal
return _LogEvaluationCallback(period=period, show_stdv=show_stdv) return _LogEvaluationCallback(period=period, show_stdv=show_stdv)
class _RecordEvaluationCallback:
"""Internal record evaluation callable class."""
def __init__(self, eval_result: Dict[str, Dict[str, List[Any]]]) -> None:
self.order = 20
self.before_iteration = False
if not isinstance(eval_result, dict):
raise TypeError('eval_result should be a dictionary')
self.eval_result = eval_result
def _init(self, env: CallbackEnv) -> None:
self.eval_result.clear()
for item in env.evaluation_result_list:
if len(item) == 4: # regular train
data_name, eval_name = item[:2]
else: # cv
data_name, eval_name = item[1].split()
self.eval_result.setdefault(data_name, collections.OrderedDict())
if len(item) == 4:
self.eval_result[data_name].setdefault(eval_name, [])
else:
self.eval_result[data_name].setdefault(f'{eval_name}-mean', [])
self.eval_result[data_name].setdefault(f'{eval_name}-stdv', [])
def __call__(self, env: CallbackEnv) -> None:
if env.iteration == env.begin_iteration:
self._init(env)
for item in env.evaluation_result_list:
if len(item) == 4:
data_name, eval_name, result = item[:3]
self.eval_result[data_name][eval_name].append(result)
else:
data_name, eval_name = item[1].split()
res_mean, res_stdv = item[2], item[4]
self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)
def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable: def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
"""Create a callback that records the evaluation history into ``eval_result``. """Create a callback that records the evaluation history into ``eval_result``.
...@@ -126,40 +165,10 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable: ...@@ -126,40 +165,10 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
Returns Returns
------- -------
callback : callable callback : _RecordEvaluationCallback
The callback that records the evaluation history into the passed dictionary. The callback that records the evaluation history into the passed dictionary.
""" """
if not isinstance(eval_result, dict): return _RecordEvaluationCallback(eval_result=eval_result)
raise TypeError('eval_result should be a dictionary')
def _init(env: CallbackEnv) -> None:
eval_result.clear()
for item in env.evaluation_result_list:
if len(item) == 4: # regular train
data_name, eval_name = item[:2]
else: # cv
data_name, eval_name = item[1].split()
eval_result.setdefault(data_name, collections.OrderedDict())
if len(item) == 4:
eval_result[data_name].setdefault(eval_name, [])
else:
eval_result[data_name].setdefault(f'{eval_name}-mean', [])
eval_result[data_name].setdefault(f'{eval_name}-stdv', [])
def _callback(env: CallbackEnv) -> None:
if env.iteration == env.begin_iteration:
_init(env)
for item in env.evaluation_result_list:
if len(item) == 4:
data_name, eval_name, result = item[:3]
eval_result[data_name][eval_name].append(result)
else:
data_name, eval_name = item[1].split()
res_mean, res_stdv = item[2], item[4]
eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)
_callback.order = 20 # type: ignore
return _callback
def reset_parameter(**kwargs: Union[list, Callable]) -> Callable: def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
......
...@@ -5,38 +5,51 @@ import lightgbm as lgb ...@@ -5,38 +5,51 @@ import lightgbm as lgb
from .utils import pickle_obj, unpickle_obj from .utils import pickle_obj, unpickle_obj
SERIALIZERS = ["pickle", "joblib", "cloudpickle"]
@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"])
def test_early_stopping_callback_is_picklable(serializer, tmp_path): def pickle_and_unpickle_object(obj, serializer):
with lgb.basic._TempFile() as tmp_file:
pickle_obj(
obj=obj,
filepath=tmp_file.name,
serializer=serializer
)
obj_from_disk = unpickle_obj(
filepath=tmp_file.name,
serializer=serializer
)
return obj_from_disk
@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_early_stopping_callback_is_picklable(serializer):
rounds = 5 rounds = 5
callback = lgb.early_stopping(stopping_rounds=rounds) callback = lgb.early_stopping(stopping_rounds=rounds)
tmp_file = tmp_path / "early_stopping.pkl" callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
pickle_obj( assert callback_from_disk.order == 30
obj=callback, assert callback_from_disk.before_iteration is False
filepath=tmp_file,
serializer=serializer
)
callback_from_disk = unpickle_obj(
filepath=tmp_file,
serializer=serializer
)
assert callback.stopping_rounds == callback_from_disk.stopping_rounds assert callback.stopping_rounds == callback_from_disk.stopping_rounds
assert callback.stopping_rounds == rounds assert callback.stopping_rounds == rounds
@pytest.mark.parametrize('serializer', ["pickle", "joblib", "cloudpickle"]) @pytest.mark.parametrize('serializer', SERIALIZERS)
def test_log_evaluation_callback_is_picklable(serializer, tmp_path): def test_log_evaluation_callback_is_picklable(serializer):
periods = 42 periods = 42
callback = lgb.log_evaluation(period=periods) callback = lgb.log_evaluation(period=periods)
tmp_file = tmp_path / "log_evaluation.pkl" callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
pickle_obj( assert callback_from_disk.order == 10
obj=callback, assert callback_from_disk.before_iteration is False
filepath=tmp_file,
serializer=serializer
)
callback_from_disk = unpickle_obj(
filepath=tmp_file,
serializer=serializer
)
assert callback.period == callback_from_disk.period assert callback.period == callback_from_disk.period
assert callback.period == periods assert callback.period == periods
@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_record_evaluation_callback_is_picklable(serializer):
results = {}
callback = lgb.record_evaluation(eval_result=results)
callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
assert callback_from_disk.order == 20
assert callback_from_disk.before_iteration is False
assert callback.eval_result == callback_from_disk.eval_result
assert callback.eval_result is results
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment