Commit b59a5a4c authored by Guolin Ke

test for early_stopping

parent f8267a50
@@ -148,8 +148,11 @@ def early_stop(stopping_rounds, verbose=True):
     callback : function
         The requested callback function.
     """
-    is_init = False
-    final_best_iter = 0
+    state = {}
+    factor_to_bigger_better = {}
+    best_score = {}
+    best_iter = {}
+    best_msg = {}
     def init(env):
         """internal function"""
         bst = env.model
@@ -160,19 +163,20 @@ def early_stop(stopping_rounds, verbose=True):
         if verbose:
             msg = "Will train until hasn't improved in {} rounds.\n"
             print(msg.format(stopping_rounds))
-        best_scores = [float('-inf') for _ in range(len(env.evaluation_result_list))]
-        best_iter = [0 for _ in range(len(env.evaluation_result_list))]
-        if verbose:
-            best_msg = ["" for _ in range(len(env.evaluation_result_list))]
-        factor_to_bigger_better = [-1.0 for _ in range(len(env.evaluation_result_list))]
         for i in range(len(env.evaluation_result_list)):
-            if evaluation.evaluation_result_list[i][3]:
+            best_score[i] = float('-inf')
+            best_iter[i] = 0
+            if verbose:
+                best_msg[i] = ""
+            factor_to_bigger_better[i] = -1.0
+            if env.evaluation_result_list[i][3]:
                 factor_to_bigger_better[i] = 1.0
-        is_init = True
+        state['best_iter'] = 0
     def callback(env):
         """internal function"""
-        if not is_init:
+        if len(best_score) == 0:
             init(env)
         for i in range(len(env.evaluation_result_list)):
             score = env.evaluation_result_list[i][2] * factor_to_bigger_better[i]
@@ -184,7 +188,7 @@ def early_stop(stopping_rounds, verbose=True):
                     '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
             else:
                 if env.iteration - best_iter[i] >= stopping_rounds:
-                    final_best_iter = best_iter[i]
+                    state['best_iter'] = best_iter[i]
                     if env.model is not None:
                         env.model.set_attr(best_iteration=str(best_iter[i]))
                     if verbose:
...
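The hunks above replace rebinding of closure locals (`is_init = True`, `final_best_iter = ...`) inside the nested `init`/`callback` functions with mutation of dicts captured by the closure. Without a `nonlocal` declaration (which does not exist in Python 2), assigning to such a name inside a nested function silently creates a new local, so the old flags never actually changed in the enclosing scope. A minimal sketch of the difference, independent of LightGBM:

```python
# Illustrative sketch (not LightGBM code) of the closure bug this commit
# works around: assigning to a name inside a nested function creates a new
# *local*, so a flag like `is_init` in the old code never changed.
# Mutating a captured dict works in both Python 2 and 3.

def broken():
    is_init = False
    def init():
        is_init = True           # rebinds a new local; outer name untouched
    init()
    return is_init               # still False

def fixed():
    state = {'is_init': False}
    def init():
        state['is_init'] = True  # mutates the shared dict in place
    init()
    return state['is_init']      # True

assert broken() is False
assert fixed() is True
```

The `factor_to_bigger_better` values of 1.0 and -1.0 in the same hunk normalize every metric so the single comparison `score > best_score[i]` works whether the metric is higher-is-better or lower-is-better.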
@@ -112,7 +112,7 @@ def train(params, train_data, num_boost_round=100,
     if is_str(init_model):
         predictor = Predictor(model_file=init_model)
     elif isinstance(init_model, Booster):
-        predictor = Booster.to_predictor()
+        predictor = init_model.to_predictor()
     elif isinstance(init_model, Predictor):
         predictor = init_model
     else:
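The one-line fix above corrects a method call made on the class rather than the instance: `Booster.to_predictor()` passes no `self`, so continuing training from an existing `Booster` could never have worked. A generic sketch of the failure mode, using hypothetical stand-in classes rather than LightGBM's real ones:

```python
# Hypothetical stand-ins showing why `Booster.to_predictor()` fails while
# `init_model.to_predictor()` works: the class-level call has no `self`.
class Predictor(object):
    pass

class Booster(object):
    def to_predictor(self):
        return Predictor()

init_model = Booster()

try:
    Booster.to_predictor()               # old code: TypeError, missing `self`
except TypeError as exc:
    print("class-level call fails:", exc)

predictor = init_model.to_predictor()    # fixed: bound to the actual model
assert isinstance(predictor, Predictor)
```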
@@ -409,6 +409,6 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
                                     evaluation_result_list=res))
         except callback.EarlyStopException as e:
             for k in results.keys():
-                results[k] = results[k][:(e.final_best_iter + 1)]
+                results[k] = results[k][:(e.state['best_iter'] + 1)]
             break
     return results
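The `cv` change only swaps the attribute read off the `EarlyStopException`; the surrounding logic truncates every recorded metric history to the best iteration. A self-contained sketch of that truncation, with made-up metric names and values:

```python
# Made-up cv histories: when early stopping fires at best_iter, every
# per-metric list is cut back so callers only see the iterations that count.
results = {'l2-mean': [10.0, 8.0, 7.5, 7.6, 7.9],
           'l2-stdv': [1.0, 0.9, 0.8, 0.8, 0.9]}
best_iter = 2  # carried out of the training loop by the exception

for k in results.keys():
    results[k] = results[k][:(best_iter + 1)]

assert all(len(v) == best_iter + 1 for v in results.values())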
@@ -96,8 +96,26 @@ def test_binary_classification_with_custom_objective():
               if int(preds[i] > 0.5) != y_test[i]) / float(len(preds))
     assert err < 0.1
+
+def test_early_stopping():
+    from sklearn.datasets import load_boston
+    from sklearn import model_selection
+    boston = load_boston()
+    y = boston['target']
+    X = boston['data']
+    x_train, x_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.1, random_state=1)
+    lgb_model = lgb.LGBMRegressor(n_estimators=500) \
+        .fit(x_train, y_train, eval_set=[(x_test, y_test)],
+             eval_metric='l2',
+             early_stopping_rounds=10,
+             verbose=10)
+    print(lgb_model.best_iteration)
+
 test_binary_classification()
 test_multiclass_classification()
 test_regression()
 test_regression_with_custom_objective()
-test_binary_classification_with_custom_objective()
\ No newline at end of file
+test_binary_classification_with_custom_objective()
+test_early_stopping()
\ No newline at end of file
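As a follow-up to the new test: the usual next step after early stopping is to predict with only the trees up to the recovered best iteration. A hedged sketch continuing from the test's variables (the attribute is spelled `best_iteration` here as in the test; later LightGBM releases expose it as `best_iteration_` on sklearn estimators):

```python
# Continuing from test_early_stopping()'s variables: cap prediction at the
# iteration early stopping selected, rather than using all 500 trees.
y_pred = lgb_model.predict(x_test, num_iteration=lgb_model.best_iteration)
```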