Commit b59a5a4c authored by Guolin Ke's avatar Guolin Ke
Browse files

test for early_stopping

parent f8267a50
......@@ -148,8 +148,11 @@ def early_stop(stopping_rounds, verbose=True):
callback : function
The requested callback function.
"""
is_init = False
final_best_iter = 0
state = {}
factor_to_bigger_better = {}
best_score = {}
best_iter = {}
best_msg = {}
def init(env):
"""internal function"""
bst = env.model
......@@ -160,19 +163,20 @@ def early_stop(stopping_rounds, verbose=True):
if verbose:
msg = "Will train until hasn't improved in {} rounds.\n"
print(msg.format(stopping_rounds))
best_scores = [ float('-inf') for _ in range(len(env.evaluation_result_list))]
best_iter = [ 0 for _ in range(len(env.evaluation_result_list))]
if verbose:
best_msg = [ "" for _ in range(len(env.evaluation_result_list))]
factor_to_bigger_better = [-1.0 for _ in range(len(env.evaluation_result_list))]
for i in range(len(env.evaluation_result_list)):
if evaluation.evaluation_result_list[i][3]:
best_score[i] = float('-inf')
best_iter[i] = 0
if verbose:
best_msg[i] = ""
factor_to_bigger_better[i] = -1.0
if env.evaluation_result_list[i][3]:
factor_to_bigger_better[i] = 1.0
is_init = True
state['best_iter'] = 0
def callback(env):
"""internal function"""
if not is_init:
if len(best_score) == 0:
init(env)
for i in range(len(env.evaluation_result_list)):
score = env.evaluation_result_list[i][2] * factor_to_bigger_better[i]
......@@ -184,7 +188,7 @@ def early_stop(stopping_rounds, verbose=True):
'\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
else:
if env.iteration - best_iter[i] >= stopping_rounds:
final_best_iter = best_iter[i]
state['best_iter'] = best_iter[i]
if env.model is not None:
env.model.set_attr(best_iteration=str(best_iter[i]))
if verbose:
......
......@@ -112,7 +112,7 @@ def train(params, train_data, num_boost_round=100,
if is_str(init_model):
predictor = Predictor(model_file=init_model)
elif isinstance(init_model, Booster):
predictor = Booster.to_predictor()
predictor = init_model.to_predictor()
elif isinstance(init_model, Predictor):
predictor = init_model
else:
......@@ -409,6 +409,6 @@ def cv(params, train_data, num_boost_round=10, nfold=5, stratified=False,
evaluation_result_list=res))
except callback.EarlyStopException as e:
for k in results.keys():
results[k] = results[k][:(e.final_best_iter + 1)]
results[k] = results[k][:(e.state['best_iter'] + 1)]
break
return results
......@@ -96,8 +96,26 @@ def test_binary_classification_with_custom_objective():
if int(preds[i] > 0.5) != y_test[i]) / float(len(preds))
assert err < 0.1
def test_early_stopping():
    """Smoke-test early stopping on LGBMRegressor with the Boston dataset.

    Fits a 500-tree regressor with a held-out eval set and
    ``early_stopping_rounds=10``, then prints the best iteration found.
    Raises if the sklearn-style fit path or early stopping is broken.
    """
    # Fix: the original imported mean_squared_error (unused) and KFold from
    # sklearn.cross_validation (unused, and that module was removed from
    # scikit-learn, so the import itself would fail). Only the names that
    # are actually used are imported here.
    from sklearn.datasets import load_boston
    from sklearn import model_selection

    boston = load_boston()
    y = boston['target']
    X = boston['data']
    x_train, x_test, y_train, y_test = model_selection.train_test_split(
        X, y, test_size=0.1, random_state=1)
    # early_stopping_rounds=10: stop once the eval-set 'l2' metric has not
    # improved for 10 consecutive rounds; verbose=10 logs every 10 rounds.
    lgb_model = lgb.LGBMRegressor(n_estimators=500) \
        .fit(x_train, y_train, eval_set=[(x_test, y_test)],
             eval_metric='l2',
             early_stopping_rounds=10,
             verbose=10)
    print(lgb_model.best_iteration)
test_binary_classification()
test_multiclass_classification()
test_regression()
test_regression_with_custom_objective()
test_binary_classification_with_custom_objective()
\ No newline at end of file
test_binary_classification_with_custom_objective()
test_early_stopping()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment