Commit adb8fb4e authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

refine callback (#223)

parent 8c2341b3
......@@ -3,6 +3,7 @@
from __future__ import absolute_import
import collections
from operator import gt, lt
from .compat import range_
......@@ -159,48 +160,50 @@ def early_stopping(stopping_rounds, verbose=True):
callback : function
The requested callback function.
"""
factor_to_bigger_better = {}
best_score = {}
best_iter = {}
best_msg = {}
best_score = []
best_iter = []
best_msg = []
cmp_op = []
def init(env):
"""internal function"""
if not env.evaluation_result_list:
raise ValueError('For early stopping, at least one dataset or eval metric is required for evaluation')
raise ValueError('For early stopping, at least one dataset and eval metric is required for evaluation')
if verbose:
msg = "Train until valid scores didn't improve in {} rounds."
print(msg.format(stopping_rounds))
for i in range_(len(env.evaluation_result_list)):
best_score[i] = float('-inf')
best_iter[i] = 0
for eval_ret in env.evaluation_result_list:
best_iter.append(0)
if verbose:
best_msg[i] = ""
factor_to_bigger_better[i] = 1.0 if env.evaluation_result_list[i][3] else -1.0
best_msg.append(None)
if eval_ret[3]:
best_score.append(float('-inf'))
cmp_op.append(gt)
else:
best_score.append(float('inf'))
cmp_op.append(lt)
def callback(env):
"""internal function"""
if not best_score:
if not cmp_op:
init(env)
best_msg_buffer = None
for i in range_(len(env.evaluation_result_list)):
score = env.evaluation_result_list[i][2] * factor_to_bigger_better[i]
if score > best_score[i]:
score = env.evaluation_result_list[i][2]
if cmp_op[i](score, best_score[i]):
best_score[i] = score
best_iter[i] = env.iteration
if verbose:
best_msg[i] = '[%d]\t%s' % (
env.iteration + 1, '\t'.join(
[_format_eval_result(x) for x in env.evaluation_result_list]
)
)
else:
if env.iteration - best_iter[i] >= stopping_rounds:
env.model.set_attr(best_iteration=str(best_iter[i]))
if verbose:
print('Early stopping, best iteration is:')
print(best_msg[i])
raise EarlyStopException(best_iter[i])
if not best_msg_buffer:
best_msg_buffer = '[%d]\t%s' % (
env.iteration + 1, '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
best_msg[i] = best_msg_buffer
elif env.iteration - best_iter[i] >= stopping_rounds:
env.model.set_attr(best_iteration=str(best_iter[i]))
if verbose:
print('Early stopping, best iteration is:\n' + best_msg[i])
raise EarlyStopException(best_iter[i])
callback.order = 30
return callback
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment