Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
6ed335df
Commit
6ed335df
authored
Mar 28, 2017
by
wxchan
Committed by
Guolin Ke
Mar 28, 2017
Browse files
refine early stopping and add a test case (#369)
parent
1141ed9d
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
8 deletions
+33
-8
python-package/lightgbm/callback.py
python-package/lightgbm/callback.py
+0
-1
python-package/lightgbm/engine.py
python-package/lightgbm/engine.py
+7
-7
tests/python_package_test/test_engine.py
tests/python_package_test/test_engine.py
+26
-0
No files found.
python-package/lightgbm/callback.py
View file @
6ed335df
...
...
@@ -201,7 +201,6 @@ def early_stopping(stopping_rounds, verbose=True):
env
.
iteration
+
1
,
'
\t
'
.
join
([
_format_eval_result
(
x
)
for
x
in
env
.
evaluation_result_list
]))
best_msg
[
i
]
=
best_msg_buffer
elif
env
.
iteration
-
best_iter
[
i
]
>=
stopping_rounds
:
env
.
model
.
set_attr
(
best_iteration
=
str
(
best_iter
[
i
]))
if
verbose
:
print
(
'Early stopping, best iteration is:
\n
'
+
best_msg
[
i
])
raise
EarlyStopException
(
best_iter
[
i
])
...
...
python-package/lightgbm/engine.py
View file @
6ed335df
...
...
@@ -165,6 +165,7 @@ def train(params, train_set, num_boost_round=100,
booster
.
set_train_data_name
(
train_data_name
)
for
valid_set
,
name_valid_set
in
zip
(
reduced_valid_sets
,
name_valid_sets
):
booster
.
add_valid
(
valid_set
,
name_valid_set
)
booster
.
best_iteration
=
-
1
"""start training"""
for
i
in
range_
(
init_iteration
,
init_iteration
+
num_boost_round
):
...
...
@@ -192,12 +193,9 @@ def train(params, train_set, num_boost_round=100,
begin_iteration
=
init_iteration
,
end_iteration
=
init_iteration
+
num_boost_round
,
evaluation_result_list
=
evaluation_result_list
))
except
callback
.
EarlyStopException
:
except
callback
.
EarlyStopException
as
earlyStopException
:
booster
.
best_iteration
=
earlyStopException
.
best_iteration
+
1
break
if
booster
.
attr
(
'best_iteration'
)
is
not
None
:
booster
.
best_iteration
=
int
(
booster
.
attr
(
'best_iteration'
))
+
1
else
:
booster
.
best_iteration
=
-
1
return
booster
...
...
@@ -205,6 +203,7 @@ class CVBooster(object):
""""Auxiliary data struct to hold all boosters of CV."""
def
__init__
(
self
):
self
.
boosters
=
[]
self
.
best_iteration
=
-
1
def
append
(
self
,
booster
):
"""add a booster to CVBooster"""
...
...
@@ -408,8 +407,9 @@ def cv(params, train_set, num_boost_round=10,
begin_iteration
=
0
,
end_iteration
=
num_boost_round
,
evaluation_result_list
=
res
))
except
callback
.
EarlyStopException
as
e
:
except
callback
.
EarlyStopException
as
earlyStopException
:
cvfolds
.
best_iteration
=
earlyStopException
.
best_iteration
+
1
for
k
in
results
:
results
[
k
]
=
results
[
k
][:
e
.
best_iteration
+
1
]
results
[
k
]
=
results
[
k
][:
cvfolds
.
best_iteration
]
break
return
dict
(
results
)
tests/python_package_test/test_engine.py
View file @
6ed335df
...
...
@@ -86,6 +86,32 @@ class TestEngine(unittest.TestCase):
self
.
assertLess
(
ret
,
0.2
)
self
.
assertAlmostEqual
(
min
(
evals_result
[
'eval'
][
'multi_logloss'
]),
ret
,
places
=
5
)
def
test_early_stopping
(
self
):
X_y
=
load_breast_cancer
(
True
)
params
=
{
'objective'
:
'binary'
,
'metric'
:
'binary_logloss'
,
'verbose'
:
-
1
,
'seed'
:
42
}
X_train
,
X_test
,
y_train
,
y_test
=
train_test_split
(
*
X_y
,
test_size
=
0.1
,
random_state
=
42
)
lgb_train
=
lgb
.
Dataset
(
X_train
,
y_train
)
lgb_eval
=
lgb
.
Dataset
(
X_test
,
y_test
,
reference
=
lgb_train
)
# no early stopping
gbm
=
lgb
.
train
(
params
,
lgb_train
,
num_boost_round
=
10
,
valid_sets
=
lgb_eval
,
verbose_eval
=
False
,
early_stopping_rounds
=
5
)
self
.
assertEqual
(
gbm
.
best_iteration
,
-
1
)
# early stopping occurs
gbm
=
lgb
.
train
(
params
,
lgb_train
,
num_boost_round
=
100
,
valid_sets
=
lgb_eval
,
verbose_eval
=
False
,
early_stopping_rounds
=
5
)
self
.
assertLessEqual
(
gbm
.
best_iteration
,
100
)
def
test_continue_train_and_other
(
self
):
params
=
{
'objective'
:
'regression'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment