"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "57ad0149f9a4342298d95bf0b203115ddf84d7e1"
Unverified Commit 4ae3d138 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] make `reset_parameter` callback pickleable (#5109)

parent 3ed0027b
......@@ -130,7 +130,8 @@ class _RecordEvaluationCallback:
self.eval_result[data_name][eval_name].append(result)
else:
data_name, eval_name = item[1].split()
res_mean, res_stdv = item[2], item[4]
res_mean = item[2]
res_stdv = item[4]
self.eval_result[data_name][f'{eval_name}-mean'].append(res_mean)
self.eval_result[data_name][f'{eval_name}-stdv'].append(res_stdv)
......@@ -171,6 +172,34 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
return _RecordEvaluationCallback(eval_result=eval_result)
class _ResetParameterCallback:
"""Internal reset parameter callable class."""
def __init__(self, **kwargs: Union[list, Callable]) -> None:
self.order = 10
self.before_iteration = True
self.kwargs = kwargs
def __call__(self, env: CallbackEnv) -> None:
new_parameters = {}
for key, value in self.kwargs.items():
if isinstance(value, list):
if len(value) != env.end_iteration - env.begin_iteration:
raise ValueError(f"Length of list {key!r} has to be equal to 'num_boost_round'.")
new_param = value[env.iteration - env.begin_iteration]
elif callable(value):
new_param = value(env.iteration - env.begin_iteration)
else:
raise ValueError("Only list and callable values are supported "
"as a mapping from boosting round index to new parameter value.")
if new_param != env.params.get(key, None):
new_parameters[key] = new_param
if new_parameters:
env.model.reset_parameter(new_parameters)
env.params.update(new_parameters)
def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
"""Create a callback that resets the parameter after the first iteration.
......@@ -189,26 +218,10 @@ def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
Returns
-------
callback : callable
callback : _ResetParameterCallback
The callback that resets the parameter after the first iteration.
"""
def _callback(env: CallbackEnv) -> None:
new_parameters = {}
for key, value in kwargs.items():
if isinstance(value, list):
if len(value) != env.end_iteration - env.begin_iteration:
raise ValueError(f"Length of list {key!r} has to equal to 'num_boost_round'.")
new_param = value[env.iteration - env.begin_iteration]
else:
new_param = value(env.iteration - env.begin_iteration)
if new_param != env.params.get(key, None):
new_parameters[key] = new_param
if new_parameters:
env.model.reset_parameter(new_parameters)
env.params.update(new_parameters)
_callback.before_iteration = True # type: ignore
_callback.order = 10 # type: ignore
return _callback
return _ResetParameterCallback(**kwargs)
class _EarlyStoppingCallback:
......
......@@ -22,6 +22,10 @@ def pickle_and_unpickle_object(obj, serializer):
return obj_from_disk
def reset_feature_fraction(boosting_round):
return 0.6 if boosting_round < 15 else 0.8
@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_early_stopping_callback_is_picklable(serializer):
rounds = 5
......@@ -53,3 +57,17 @@ def test_record_evaluation_callback_is_picklable(serializer):
assert callback_from_disk.before_iteration is False
assert callback.eval_result == callback_from_disk.eval_result
assert callback.eval_result is results
@pytest.mark.parametrize('serializer', SERIALIZERS)
def test_reset_parameter_callback_is_picklable(serializer):
params = {
'bagging_fraction': [0.7] * 5 + [0.6] * 5,
'feature_fraction': reset_feature_fraction
}
callback = lgb.reset_parameter(**params)
callback_from_disk = pickle_and_unpickle_object(obj=callback, serializer=serializer)
assert callback_from_disk.order == 10
assert callback_from_disk.before_iteration is True
assert callback.kwargs == callback_from_disk.kwargs
assert callback.kwargs == params
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment