Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
tianlh
LightGBM-DCU
Commits
b59a5a4c
Commit
b59a5a4c
authored
Nov 30, 2016
by
Guolin Ke
Browse files
test for early_stopping
parent
f8267a50
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
14 deletions
+36
-14
python-package/lightgbm/callback.py
python-package/lightgbm/callback.py
+15
-11
python-package/lightgbm/engine.py
python-package/lightgbm/engine.py
+2
-2
tests/python_package_test/test_sklearn.py
tests/python_package_test/test_sklearn.py
+19
-1
No files found.
python-package/lightgbm/callback.py
View file @
b59a5a4c
...
@@ -148,8 +148,11 @@ def early_stop(stopping_rounds, verbose=True):
...
@@ -148,8 +148,11 @@ def early_stop(stopping_rounds, verbose=True):
callback : function
callback : function
The requested callback function.
The requested callback function.
"""
"""
is_init
=
False
state
=
{}
final_best_iter
=
0
factor_to_bigger_better
=
{}
best_score
=
{}
best_iter
=
{}
best_msg
=
{}
def
init
(
env
):
def
init
(
env
):
"""internal function"""
"""internal function"""
bst
=
env
.
model
bst
=
env
.
model
...
@@ -160,19 +163,20 @@ def early_stop(stopping_rounds, verbose=True):
...
@@ -160,19 +163,20 @@ def early_stop(stopping_rounds, verbose=True):
if
verbose
:
if
verbose
:
msg
=
"Will train until hasn't improved in {} rounds.
\n
"
msg
=
"Will train until hasn't improved in {} rounds.
\n
"
print
(
msg
.
format
(
stopping_rounds
))
print
(
msg
.
format
(
stopping_rounds
))
best_scores
=
[
float
(
'-inf'
)
for
_
in
range
(
len
(
env
.
evaluation_result_list
))]
best_iter
=
[
0
for
_
in
range
(
len
(
env
.
evaluation_result_list
))]
if
verbose
:
best_msg
=
[
""
for
_
in
range
(
len
(
env
.
evaluation_result_list
))]
factor_to_bigger_better
=
[
-
1.0
for
_
in
range
(
len
(
env
.
evaluation_result_list
))]
for
i
in
range
(
len
(
env
.
evaluation_result_list
)):
for
i
in
range
(
len
(
env
.
evaluation_result_list
)):
if
evaluation
.
evaluation_result_list
[
i
][
3
]:
best_score
[
i
]
=
float
(
'-inf'
)
best_iter
[
i
]
=
0
if
verbose
:
best_msg
[
i
]
=
""
factor_to_bigger_better
[
i
]
=
-
1.0
if
env
.
evaluation_result_list
[
i
][
3
]:
factor_to_bigger_better
[
i
]
=
1.0
factor_to_bigger_better
[
i
]
=
1.0
is_init
=
True
state
[
'best_iter'
]
=
0
def
callback
(
env
):
def
callback
(
env
):
"""internal function"""
"""internal function"""
if
not
is_init
:
if
len
(
best_score
)
==
0
:
init
(
env
)
init
(
env
)
for
i
in
range
(
len
(
env
.
evaluation_result_list
)):
for
i
in
range
(
len
(
env
.
evaluation_result_list
)):
score
=
env
.
evaluation_result_list
[
i
][
2
]
*
factor_to_bigger_better
[
i
]
score
=
env
.
evaluation_result_list
[
i
][
2
]
*
factor_to_bigger_better
[
i
]
...
@@ -184,7 +188,7 @@ def early_stop(stopping_rounds, verbose=True):
...
@@ -184,7 +188,7 @@ def early_stop(stopping_rounds, verbose=True):
'
\t
'
.
join
([
_format_eval_result
(
x
)
for
x
in
env
.
evaluation_result_list
]))
'
\t
'
.
join
([
_format_eval_result
(
x
)
for
x
in
env
.
evaluation_result_list
]))
else
:
else
:
if
env
.
iteration
-
best_iter
[
i
]
>=
stopping_rounds
:
if
env
.
iteration
-
best_iter
[
i
]
>=
stopping_rounds
:
final_
best_iter
=
best_iter
[
i
]
state
[
'
best_iter
'
]
=
best_iter
[
i
]
if
env
.
model
is
not
None
:
if
env
.
model
is
not
None
:
env
.
model
.
set_attr
(
best_iteration
=
str
(
best_iter
[
i
]))
env
.
model
.
set_attr
(
best_iteration
=
str
(
best_iter
[
i
]))
if
verbose
:
if
verbose
:
...
...
python-package/lightgbm/engine.py
View file @
b59a5a4c
...
@@ -112,7 +112,7 @@ def train(params, train_data, num_boost_round=100,
...
@@ -112,7 +112,7 @@ def train(params, train_data, num_boost_round=100,
if
is_str
(
init_model
):
if
is_str
(
init_model
):
predictor
=
Predictor
(
model_file
=
init_model
)
predictor
=
Predictor
(
model_file
=
init_model
)
elif
isinstance
(
init_model
,
Booster
):
elif
isinstance
(
init_model
,
Booster
):
predictor
=
Booster
.
to_predictor
()
predictor
=
init_model
.
to_predictor
()
elif
isinstance
(
init_model
,
Predictor
):
elif
isinstance
(
init_model
,
Predictor
):
predictor
=
init_model
predictor
=
init_model
else
:
else
:
...
@@ -409,6 +409,6 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
...
@@ -409,6 +409,6 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
evaluation_result_list
=
res
))
evaluation_result_list
=
res
))
except
callback
.
EarlyStopException
as
e
:
except
callback
.
EarlyStopException
as
e
:
for
k
in
results
.
keys
():
for
k
in
results
.
keys
():
results
[
k
]
=
results
[
k
][:(
e
.
final_
best_iter
+
1
)]
results
[
k
]
=
results
[
k
][:(
e
.
state
[
'
best_iter
'
]
+
1
)]
break
break
return
results
return
results
tests/python_package_test/test_sklearn.py
View file @
b59a5a4c
...
@@ -96,8 +96,26 @@ def test_binary_classification_with_custom_objective():
...
@@ -96,8 +96,26 @@ def test_binary_classification_with_custom_objective():
if
int
(
preds
[
i
]
>
0.5
)
!=
y_test
[
i
])
/
float
(
len
(
preds
))
if
int
(
preds
[
i
]
>
0.5
)
!=
y_test
[
i
])
/
float
(
len
(
preds
))
assert
err
<
0.1
assert
err
<
0.1
def
test_early_stopping
():
from
sklearn.metrics
import
mean_squared_error
from
sklearn.datasets
import
load_boston
from
sklearn.cross_validation
import
KFold
from
sklearn
import
datasets
,
metrics
,
model_selection
boston
=
load_boston
()
y
=
boston
[
'target'
]
X
=
boston
[
'data'
]
x_train
,
x_test
,
y_train
,
y_test
=
model_selection
.
train_test_split
(
X
,
y
,
test_size
=
0.1
,
random_state
=
1
)
lgb_model
=
lgb
.
LGBMRegressor
(
n_estimators
=
500
)
\
.
fit
(
x_train
,
y_train
,
eval_set
=
[(
x_test
,
y_test
)],
eval_metric
=
'l2'
,
early_stopping_rounds
=
10
,
verbose
=
10
)
print
(
lgb_model
.
best_iteration
)
test_binary_classification
()
test_binary_classification
()
test_multiclass_classification
()
test_multiclass_classification
()
test_regression
()
test_regression
()
test_regression_with_custom_objective
()
test_regression_with_custom_objective
()
test_binary_classification_with_custom_objective
()
test_binary_classification_with_custom_objective
()
test_early_stopping
()
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment