Commit 45c1c6e8 authored by wxchan, committed by Guolin Ke
Browse files

add best score (#413)

parent 9224a9d1
...@@ -1179,6 +1179,7 @@ class Booster(object): ...@@ -1179,6 +1179,7 @@ class Booster(object):
self.__train_data_name = "training" self.__train_data_name = "training"
self.__attr = {} self.__attr = {}
self.best_iteration = -1 self.best_iteration = -1
self.best_score = {}
params = {} if params is None else params params = {} if params is None else params
if silent: if silent:
params["verbose"] = 0 params["verbose"] = 0
......
...@@ -15,9 +15,10 @@ class EarlyStopException(Exception): ...@@ -15,9 +15,10 @@ class EarlyStopException(Exception):
best_iteration : int best_iteration : int
The best iteration stopped. The best iteration stopped.
""" """
def __init__(self, best_iteration): def __init__(self, best_iteration, best_score):
super(EarlyStopException, self).__init__() super(EarlyStopException, self).__init__()
self.best_iteration = best_iteration self.best_iteration = best_iteration
self.best_score = best_score
# Callback environment used by callbacks # Callback environment used by callbacks
...@@ -162,7 +163,7 @@ def early_stopping(stopping_rounds, verbose=True): ...@@ -162,7 +163,7 @@ def early_stopping(stopping_rounds, verbose=True):
""" """
best_score = [] best_score = []
best_iter = [] best_iter = []
best_msg = [] best_score_list = []
cmp_op = [] cmp_op = []
def init(env): def init(env):
...@@ -176,8 +177,7 @@ def early_stopping(stopping_rounds, verbose=True): ...@@ -176,8 +177,7 @@ def early_stopping(stopping_rounds, verbose=True):
for eval_ret in env.evaluation_result_list: for eval_ret in env.evaluation_result_list:
best_iter.append(0) best_iter.append(0)
if verbose: best_score_list.append(None)
best_msg.append(None)
if eval_ret[3]: if eval_ret[3]:
best_score.append(float('-inf')) best_score.append(float('-inf'))
cmp_op.append(gt) cmp_op.append(gt)
...@@ -189,20 +189,16 @@ def early_stopping(stopping_rounds, verbose=True): ...@@ -189,20 +189,16 @@ def early_stopping(stopping_rounds, verbose=True):
"""internal function""" """internal function"""
if not cmp_op: if not cmp_op:
init(env) init(env)
best_msg_buffer = None
for i in range_(len(env.evaluation_result_list)): for i in range_(len(env.evaluation_result_list)):
score = env.evaluation_result_list[i][2] score = env.evaluation_result_list[i][2]
if cmp_op[i](score, best_score[i]): if cmp_op[i](score, best_score[i]):
best_score[i] = score best_score[i] = score
best_iter[i] = env.iteration best_iter[i] = env.iteration
if verbose: best_score_list[i] = env.evaluation_result_list
if not best_msg_buffer:
best_msg_buffer = '[%d]\t%s' % (
env.iteration + 1, '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
best_msg[i] = best_msg_buffer
elif env.iteration - best_iter[i] >= stopping_rounds: elif env.iteration - best_iter[i] >= stopping_rounds:
if verbose: if verbose:
print('Early stopping, best iteration is:\n' + best_msg[i]) print('Early stopping, best iteration is:\n[%d]\t%s' % (
raise EarlyStopException(best_iter[i]) best_iter[i] + 1, '\t'.join([_format_eval_result(x) for x in best_score_list[i]])))
raise EarlyStopException(best_iter[i], best_score_list[i])
callback.order = 30 callback.order = 30
return callback return callback
...@@ -195,7 +195,11 @@ def train(params, train_set, num_boost_round=100, ...@@ -195,7 +195,11 @@ def train(params, train_set, num_boost_round=100,
evaluation_result_list=evaluation_result_list)) evaluation_result_list=evaluation_result_list))
except callback.EarlyStopException as earlyStopException: except callback.EarlyStopException as earlyStopException:
booster.best_iteration = earlyStopException.best_iteration + 1 booster.best_iteration = earlyStopException.best_iteration + 1
evaluation_result_list = earlyStopException.best_score
break break
booster.best_score = collections.defaultdict(dict)
for dataset_name, eval_name, score, _ in evaluation_result_list:
booster.best_score[dataset_name][eval_name] = score
return booster return booster
......
...@@ -273,6 +273,7 @@ class LGBMModel(LGBMModelBase): ...@@ -273,6 +273,7 @@ class LGBMModel(LGBMModelBase):
self._Booster = None self._Booster = None
self.evals_result = None self.evals_result = None
self.best_iteration = -1 self.best_iteration = -1
self.best_score = {}
if callable(self.objective): if callable(self.objective):
self.fobj = _objective_function_wrapper(self.objective) self.fobj = _objective_function_wrapper(self.objective)
else: else:
...@@ -414,6 +415,7 @@ class LGBMModel(LGBMModelBase): ...@@ -414,6 +415,7 @@ class LGBMModel(LGBMModelBase):
if early_stopping_rounds is not None: if early_stopping_rounds is not None:
self.best_iteration = self._Booster.best_iteration self.best_iteration = self._Booster.best_iteration
self.best_score = self._Booster.best_score
return self return self
def predict(self, X, raw_score=False, num_iteration=0): def predict(self, X, raw_score=False, num_iteration=0):
......
...@@ -96,20 +96,27 @@ class TestEngine(unittest.TestCase): ...@@ -96,20 +96,27 @@ class TestEngine(unittest.TestCase):
X_train, X_test, y_train, y_test = train_test_split(*X_y, test_size=0.1, random_state=42) X_train, X_test, y_train, y_test = train_test_split(*X_y, test_size=0.1, random_state=42)
lgb_train = lgb.Dataset(X_train, y_train) lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train) lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
valid_set_name = 'valid_set'
# no early stopping # no early stopping
gbm = lgb.train(params, lgb_train, gbm = lgb.train(params, lgb_train,
num_boost_round=10, num_boost_round=10,
valid_sets=lgb_eval, valid_sets=lgb_eval,
valid_names=valid_set_name,
verbose_eval=False, verbose_eval=False,
early_stopping_rounds=5) early_stopping_rounds=5)
self.assertEqual(gbm.best_iteration, -1) self.assertEqual(gbm.best_iteration, -1)
self.assertIn(valid_set_name, gbm.best_score)
self.assertIn('binary_logloss', gbm.best_score[valid_set_name])
# early stopping occurs # early stopping occurs
gbm = lgb.train(params, lgb_train, gbm = lgb.train(params, lgb_train,
num_boost_round=100, num_boost_round=100,
valid_sets=lgb_eval, valid_sets=lgb_eval,
valid_names=valid_set_name,
verbose_eval=False, verbose_eval=False,
early_stopping_rounds=5) early_stopping_rounds=5)
self.assertLessEqual(gbm.best_iteration, 100) self.assertLessEqual(gbm.best_iteration, 100)
self.assertIn(valid_set_name, gbm.best_score)
self.assertIn('binary_logloss', gbm.best_score[valid_set_name])
def test_continue_train_and_other(self): def test_continue_train_and_other(self):
params = { params = {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment