Commit adb8fb4e authored by wxchan, committed by Guolin Ke

refine callback (#223)

parent 8c2341b3
@@ -3,6 +3,7 @@
 from __future__ import absolute_import
 
 import collections
+from operator import gt, lt
 
 from .compat import range_
@@ -159,48 +160,50 @@ def early_stopping(stopping_rounds, verbose=True):
     callback : function
         The requested callback function.
     """
-    factor_to_bigger_better = {}
-    best_score = {}
-    best_iter = {}
-    best_msg = {}
+    best_score = []
+    best_iter = []
+    best_msg = []
+    cmp_op = []
 
     def init(env):
         """internal function"""
         if not env.evaluation_result_list:
-            raise ValueError('For early stopping, at least one dataset or eval metric is required for evaluation')
+            raise ValueError('For early stopping, at least one dataset and eval metric is required for evaluation')
         if verbose:
             msg = "Train until valid scores didn't improve in {} rounds."
             print(msg.format(stopping_rounds))
-        for i in range_(len(env.evaluation_result_list)):
-            best_score[i] = float('-inf')
-            best_iter[i] = 0
+        for eval_ret in env.evaluation_result_list:
+            best_iter.append(0)
             if verbose:
-                best_msg[i] = ""
-            factor_to_bigger_better[i] = 1.0 if env.evaluation_result_list[i][3] else -1.0
+                best_msg.append(None)
+            if eval_ret[3]:
+                best_score.append(float('-inf'))
+                cmp_op.append(gt)
+            else:
+                best_score.append(float('inf'))
+                cmp_op.append(lt)
 
     def callback(env):
         """internal function"""
-        if not best_score:
+        if not cmp_op:
             init(env)
+        best_msg_buffer = None
         for i in range_(len(env.evaluation_result_list)):
-            score = env.evaluation_result_list[i][2] * factor_to_bigger_better[i]
-            if score > best_score[i]:
+            score = env.evaluation_result_list[i][2]
+            if cmp_op[i](score, best_score[i]):
                 best_score[i] = score
                 best_iter[i] = env.iteration
                 if verbose:
-                    best_msg[i] = '[%d]\t%s' % (
-                        env.iteration + 1, '\t'.join(
-                            [_format_eval_result(x) for x in env.evaluation_result_list]
-                        )
-                    )
-            else:
-                if env.iteration - best_iter[i] >= stopping_rounds:
-                    env.model.set_attr(best_iteration=str(best_iter[i]))
-                    if verbose:
-                        print('Early stopping, best iteration is:')
-                        print(best_msg[i])
-                    raise EarlyStopException(best_iter[i])
+                    if not best_msg_buffer:
+                        best_msg_buffer = '[%d]\t%s' % (
+                            env.iteration + 1, '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
+                    best_msg[i] = best_msg_buffer
+            elif env.iteration - best_iter[i] >= stopping_rounds:
+                env.model.set_attr(best_iteration=str(best_iter[i]))
+                if verbose:
+                    print('Early stopping, best iteration is:\n' + best_msg[i])
+                raise EarlyStopException(best_iter[i])
     callback.order = 30
     return callback
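
To illustrate the refactor outside the diff, here is a minimal, self-contained sketch (not LightGBM code; the metric names and scores are made up) of the comparator pattern the new version adopts: each metric gets a gt or lt comparator plus a matching infinity sentinel, instead of multiplying every score by +/-1.0 so that "bigger is always better".

from operator import gt, lt

# Hypothetical metrics: (name, higher_better), mirroring the
# eval_ret[3] flag the callback reads from evaluation_result_list.
metrics = [('auc', True), ('binary_logloss', False)]

best_score = []
cmp_op = []
for _, higher_better in metrics:
    if higher_better:
        best_score.append(float('-inf'))  # any real score beats -inf under gt
        cmp_op.append(gt)
    else:
        best_score.append(float('inf'))   # any real score beats +inf under lt
        cmp_op.append(lt)

# Made-up evaluation results for three rounds.
rounds = [[0.70, 0.60], [0.75, 0.55], [0.74, 0.56]]
for iteration, scores in enumerate(rounds):
    for i, score in enumerate(scores):
        if cmp_op[i](score, best_score[i]):  # direct comparison, no sign flip
            best_score[i] = score
            print('iter %d: new best %s = %g' % (iteration, metrics[i][0], score))

The same commit also introduces best_msg_buffer, so the tab-joined result string is formatted at most once per iteration and shared by every metric that improved, instead of being rebuilt per metric.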
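
For context, a hedged usage sketch of where this callback plugs in, written against the current public LightGBM API (which postdates this commit but keeps the stopping_rounds parameter seen here); the data is random noise purely to make the snippet runnable:

import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 5))
y = (X[:, 0] + rng.normal(scale=0.5, size=500) > 0).astype(int)
train_data = lgb.Dataset(X[:400], label=y[:400])
valid_data = lgb.Dataset(X[400:], label=y[400:], reference=train_data)

# early_stopping watches every metric on every validation set and
# raises EarlyStopException once none improve for stopping_rounds.
booster = lgb.train(
    {'objective': 'binary', 'metric': 'auc'},
    train_data,
    num_boost_round=200,
    valid_sets=[valid_data],
    callbacks=[lgb.early_stopping(stopping_rounds=10)],
)
print('best iteration:', booster.best_iteration)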