Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
4ae3d138
Unverified
Commit
4ae3d138
authored
Mar 31, 2022
by
Nikita Titov
Committed by
GitHub
Mar 31, 2022
Browse files
[python] make `reset_parameter` callback pickleable (#5109)
parent
3ed0027b
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
19 deletions
+50
-19
python-package/lightgbm/callback.py
python-package/lightgbm/callback.py
+32
-19
tests/python_package_test/test_callback.py
tests/python_package_test/test_callback.py
+18
-0
No files found.
python-package/lightgbm/callback.py
View file @
4ae3d138
...
@@ -130,7 +130,8 @@ class _RecordEvaluationCallback:
...
@@ -130,7 +130,8 @@ class _RecordEvaluationCallback:
self
.
eval_result
[
data_name
][
eval_name
].
append
(
result
)
self
.
eval_result
[
data_name
][
eval_name
].
append
(
result
)
else
:
else
:
data_name
,
eval_name
=
item
[
1
].
split
()
data_name
,
eval_name
=
item
[
1
].
split
()
res_mean
,
res_stdv
=
item
[
2
],
item
[
4
]
res_mean
=
item
[
2
]
res_stdv
=
item
[
4
]
self
.
eval_result
[
data_name
][
f
'
{
eval_name
}
-mean'
].
append
(
res_mean
)
self
.
eval_result
[
data_name
][
f
'
{
eval_name
}
-mean'
].
append
(
res_mean
)
self
.
eval_result
[
data_name
][
f
'
{
eval_name
}
-stdv'
].
append
(
res_stdv
)
self
.
eval_result
[
data_name
][
f
'
{
eval_name
}
-stdv'
].
append
(
res_stdv
)
...
@@ -171,6 +172,34 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
...
@@ -171,6 +172,34 @@ def record_evaluation(eval_result: Dict[str, Dict[str, List[Any]]]) -> Callable:
return
_RecordEvaluationCallback
(
eval_result
=
eval_result
)
return
_RecordEvaluationCallback
(
eval_result
=
eval_result
)
class
_ResetParameterCallback
:
"""Internal reset parameter callable class."""
def
__init__
(
self
,
**
kwargs
:
Union
[
list
,
Callable
])
->
None
:
self
.
order
=
10
self
.
before_iteration
=
True
self
.
kwargs
=
kwargs
def
__call__
(
self
,
env
:
CallbackEnv
)
->
None
:
new_parameters
=
{}
for
key
,
value
in
self
.
kwargs
.
items
():
if
isinstance
(
value
,
list
):
if
len
(
value
)
!=
env
.
end_iteration
-
env
.
begin_iteration
:
raise
ValueError
(
f
"Length of list
{
key
!
r
}
has to be equal to 'num_boost_round'."
)
new_param
=
value
[
env
.
iteration
-
env
.
begin_iteration
]
elif
callable
(
value
):
new_param
=
value
(
env
.
iteration
-
env
.
begin_iteration
)
else
:
raise
ValueError
(
"Only list and callable values are supported "
"as a mapping from boosting round index to new parameter value."
)
if
new_param
!=
env
.
params
.
get
(
key
,
None
):
new_parameters
[
key
]
=
new_param
if
new_parameters
:
env
.
model
.
reset_parameter
(
new_parameters
)
env
.
params
.
update
(
new_parameters
)
def
reset_parameter
(
**
kwargs
:
Union
[
list
,
Callable
])
->
Callable
:
def
reset_parameter
(
**
kwargs
:
Union
[
list
,
Callable
])
->
Callable
:
"""Create a callback that resets the parameter after the first iteration.
"""Create a callback that resets the parameter after the first iteration.
...
@@ -189,26 +218,10 @@ def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
...
@@ -189,26 +218,10 @@ def reset_parameter(**kwargs: Union[list, Callable]) -> Callable:
Returns
Returns
-------
-------
callback :
callable
callback :
_ResetParameterCallback
The callback that resets the parameter after the first iteration.
The callback that resets the parameter after the first iteration.
"""
"""
def
_callback
(
env
:
CallbackEnv
)
->
None
:
return
_ResetParameterCallback
(
**
kwargs
)
new_parameters
=
{}
for
key
,
value
in
kwargs
.
items
():
if
isinstance
(
value
,
list
):
if
len
(
value
)
!=
env
.
end_iteration
-
env
.
begin_iteration
:
raise
ValueError
(
f
"Length of list
{
key
!
r
}
has to equal to 'num_boost_round'."
)
new_param
=
value
[
env
.
iteration
-
env
.
begin_iteration
]
else
:
new_param
=
value
(
env
.
iteration
-
env
.
begin_iteration
)
if
new_param
!=
env
.
params
.
get
(
key
,
None
):
new_parameters
[
key
]
=
new_param
if
new_parameters
:
env
.
model
.
reset_parameter
(
new_parameters
)
env
.
params
.
update
(
new_parameters
)
_callback
.
before_iteration
=
True
# type: ignore
_callback
.
order
=
10
# type: ignore
return
_callback
class
_EarlyStoppingCallback
:
class
_EarlyStoppingCallback
:
...
...
tests/python_package_test/test_callback.py
View file @
4ae3d138
...
@@ -22,6 +22,10 @@ def pickle_and_unpickle_object(obj, serializer):
...
@@ -22,6 +22,10 @@ def pickle_and_unpickle_object(obj, serializer):
return
obj_from_disk
return
obj_from_disk
def
reset_feature_fraction
(
boosting_round
):
return
0.6
if
boosting_round
<
15
else
0.8
@
pytest
.
mark
.
parametrize
(
'serializer'
,
SERIALIZERS
)
@
pytest
.
mark
.
parametrize
(
'serializer'
,
SERIALIZERS
)
def
test_early_stopping_callback_is_picklable
(
serializer
):
def
test_early_stopping_callback_is_picklable
(
serializer
):
rounds
=
5
rounds
=
5
...
@@ -53,3 +57,17 @@ def test_record_evaluation_callback_is_picklable(serializer):
...
@@ -53,3 +57,17 @@ def test_record_evaluation_callback_is_picklable(serializer):
assert
callback_from_disk
.
before_iteration
is
False
assert
callback_from_disk
.
before_iteration
is
False
assert
callback
.
eval_result
==
callback_from_disk
.
eval_result
assert
callback
.
eval_result
==
callback_from_disk
.
eval_result
assert
callback
.
eval_result
is
results
assert
callback
.
eval_result
is
results
@
pytest
.
mark
.
parametrize
(
'serializer'
,
SERIALIZERS
)
def
test_reset_parameter_callback_is_picklable
(
serializer
):
params
=
{
'bagging_fraction'
:
[
0.7
]
*
5
+
[
0.6
]
*
5
,
'feature_fraction'
:
reset_feature_fraction
}
callback
=
lgb
.
reset_parameter
(
**
params
)
callback_from_disk
=
pickle_and_unpickle_object
(
obj
=
callback
,
serializer
=
serializer
)
assert
callback_from_disk
.
order
==
10
assert
callback_from_disk
.
before_iteration
is
True
assert
callback
.
kwargs
==
callback_from_disk
.
kwargs
assert
callback
.
kwargs
==
params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment