Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
d45dca70
"src/vscode:/vscode.git/clone" did not exist on "a97c444b4cf9d2755bd888911ce65ace1fe13e4b"
Unverified
Commit
d45dca70
authored
Oct 05, 2023
by
James Lamb
Committed by
GitHub
Oct 05, 2023
Browse files
[python-package] reorganize early stopping callback (#6114)
parent
f175cebd
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
58 additions
and
21 deletions
+58
-21
python-package/lightgbm/callback.py
python-package/lightgbm/callback.py
+45
-19
tests/python_package_test/test_callback.py
tests/python_package_test/test_callback.py
+11
-0
tests/python_package_test/test_engine.py
tests/python_package_test/test_engine.py
+2
-2
No files found.
python-package/lightgbm/callback.py
View file @
d45dca70
...
...
@@ -229,7 +229,12 @@ class _ResetParameterCallback:
if
new_param
!=
env
.
params
.
get
(
key
,
None
):
new_parameters
[
key
]
=
new_param
if
new_parameters
:
if
isinstance
(
env
.
model
,
Booster
):
env
.
model
.
reset_parameter
(
new_parameters
)
else
:
# CVBooster holds a list of Booster objects, each needs to be updated
for
booster
in
env
.
model
.
boosters
:
booster
.
reset_parameter
(
new_parameters
)
env
.
params
.
update
(
new_parameters
)
...
...
@@ -267,6 +272,10 @@ class _EarlyStoppingCallback:
verbose
:
bool
=
True
,
min_delta
:
Union
[
float
,
List
[
float
]]
=
0.0
)
->
None
:
if
not
isinstance
(
stopping_rounds
,
int
)
or
stopping_rounds
<=
0
:
raise
ValueError
(
f
"stopping_rounds should be an integer and greater than 0. got:
{
stopping_rounds
}
"
)
self
.
order
=
30
self
.
before_iteration
=
False
...
...
@@ -291,33 +300,46 @@ class _EarlyStoppingCallback:
def
_lt_delta
(
self
,
curr_score
:
float
,
best_score
:
float
,
delta
:
float
)
->
bool
:
return
curr_score
<
best_score
-
delta
def
_is_train_set
(
self
,
ds_name
:
str
,
eval_name
:
str
,
train_name
:
str
)
->
bool
:
return
(
ds_name
==
"cv_agg"
and
eval_name
==
"train"
)
or
ds_name
==
train_name
def
_is_train_set
(
self
,
ds_name
:
str
,
eval_name
:
str
,
env
:
CallbackEnv
)
->
bool
:
"""Check, by name, if a given Dataset is the training data."""
# for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
# and those metrics are considered for early stopping
if
ds_name
==
"cv_agg"
and
eval_name
==
"train"
:
return
True
# for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
if
isinstance
(
env
.
model
,
Booster
)
and
ds_name
==
env
.
model
.
_train_data_name
:
return
True
return
False
def
_init
(
self
,
env
:
CallbackEnv
)
->
None
:
if
env
.
evaluation_result_list
is
None
or
env
.
evaluation_result_list
==
[]:
raise
ValueError
(
"For early stopping, at least one dataset and eval metric is required for evaluation"
)
is_dart
=
any
(
env
.
params
.
get
(
alias
,
""
)
==
'dart'
for
alias
in
_ConfigAliases
.
get
(
"boosting"
))
if
is_dart
:
self
.
enabled
=
False
_log_warning
(
'Early stopping is not available in dart mode'
)
return
# validation sets are guaranteed to not be identical to the training data in cv()
if
isinstance
(
env
.
model
,
Booster
):
only_train_set
=
(
len
(
env
.
evaluation_result_list
)
==
1
and
self
.
_is_train_set
(
ds_name
=
env
.
evaluation_result_list
[
0
][
0
],
eval_name
=
env
.
evaluation_result_list
[
0
][
1
].
split
(
" "
)[
0
],
train_name
=
env
.
model
.
_train_data_name
)
env
=
env
)
self
.
enabled
=
not
is_dart
and
not
only_train_set
if
not
self
.
enabled
:
if
is_dart
:
_log_warning
(
'Early stopping is not available in dart mode'
)
elif
only_train_set
:
)
if
only_train_set
:
self
.
enabled
=
False
_log_warning
(
'Only training set found, disabling early stopping.'
)
return
if
self
.
stopping_rounds
<=
0
:
raise
ValueError
(
"stopping_rounds should be greater than zero."
)
if
self
.
verbose
:
_log_info
(
f
"Training until validation scores don't improve for
{
self
.
stopping_rounds
}
rounds"
)
...
...
@@ -395,7 +417,11 @@ class _EarlyStoppingCallback:
eval_name_splitted
=
env
.
evaluation_result_list
[
i
][
1
].
split
(
" "
)
if
self
.
first_metric_only
and
self
.
first_metric
!=
eval_name_splitted
[
-
1
]:
continue
# use only the first metric for early stopping
if
self
.
_is_train_set
(
env
.
evaluation_result_list
[
i
][
0
],
eval_name_splitted
[
0
],
env
.
model
.
_train_data_name
):
if
self
.
_is_train_set
(
ds_name
=
env
.
evaluation_result_list
[
i
][
0
],
eval_name
=
eval_name_splitted
[
0
],
env
=
env
):
continue
# train data for lgb.cv or sklearn wrapper (underlying lgb.train)
elif
env
.
iteration
-
self
.
best_iter
[
i
]
>=
self
.
stopping_rounds
:
if
self
.
verbose
:
...
...
tests/python_package_test/test_callback.py
View file @
d45dca70
...
...
@@ -21,6 +21,17 @@ def test_early_stopping_callback_is_picklable(serializer):
assert
callback
.
stopping_rounds
==
rounds
def
test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors
():
with
pytest
.
raises
(
ValueError
,
match
=
"stopping_rounds should be an integer and greater than 0. got: 0"
):
lgb
.
early_stopping
(
stopping_rounds
=
0
)
with
pytest
.
raises
(
ValueError
,
match
=
"stopping_rounds should be an integer and greater than 0. got: -1"
):
lgb
.
early_stopping
(
stopping_rounds
=-
1
)
with
pytest
.
raises
(
ValueError
,
match
=
"stopping_rounds should be an integer and greater than 0. got: neverrrr"
):
lgb
.
early_stopping
(
stopping_rounds
=
"neverrrr"
)
@
pytest
.
mark
.
parametrize
(
'serializer'
,
SERIALIZERS
)
def
test_log_evaluation_callback_is_picklable
(
serializer
):
periods
=
42
...
...
tests/python_package_test/test_engine.py
View file @
d45dca70
...
...
@@ -4501,9 +4501,9 @@ def test_train_raises_informative_error_if_any_valid_sets_are_not_dataset_object
def
test_train_raises_informative_error_for_params_of_wrong_type
():
X
,
y
=
make_synthetic_regression
()
params
=
{
"
early_stopping_round
"
:
"too-many"
}
params
=
{
"
num_leaves
"
:
"too-many"
}
dtrain
=
lgb
.
Dataset
(
X
,
label
=
y
)
with
pytest
.
raises
(
lgb
.
basic
.
LightGBMError
,
match
=
"Parameter
early_stopping_round
should be of type int, got
\"
too-many
\"
"
):
with
pytest
.
raises
(
lgb
.
basic
.
LightGBMError
,
match
=
"Parameter
num_leaves
should be of type int, got
\"
too-many
\"
"
):
lgb
.
train
(
params
,
dtrain
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment