Commit 27624755 authored by Guolin Ke's avatar Guolin Ke

add main training logic and callbacks

parent 63eddae0
...@@ -787,7 +787,6 @@ class Dataset(object):
             ctypes.byref(ret)))
         return ret.value

 class Booster(object):
     """A Booster of LightGBM.
     """
...@@ -808,6 +807,7 @@ class Booster(object):
         self.handle = ctypes.c_void_p()
         self.__need_reload_eval_info = True
         self.__is_manage_handle = True
+        self.__train_data_name = "training"
         params = {} if params is None else params
         if silent:
             params["verbose"] = 0
...@@ -861,6 +861,9 @@ class Booster(object):
         if self.handle is not None and self.__is_manage_handle:
             _safe_call(_LIB.LGBM_BoosterFree(self.handle))

+    def set_train_data_name(self, name):
+        self.__train_data_name = name
+
     def add_valid(self, data, name):
         """Add a validation dataset.
...@@ -882,7 +885,7 @@ class Booster(object):
         self.__inner_predict_buffer.append(None)
         self.__is_predicted_cur_iter.append(False)

-    def reset_parameter(self, params, silent=False):
+    def reset_parameter(self, params):
         """Reset parameters for booster

         Parameters
...@@ -892,11 +895,8 @@ class Booster(object):
         silent : boolean, optional
             Whether to print messages during construction
         """
-        if 'metric' in params:
-            self.__need_reload_eval_info = True
-        if silent:
-            params["verbose"] = 0
-        elif "verbose" not in params:
-            params["verbose"] = 1
+        self.__need_reload_eval_info = True
         params_str = param_dict_to_str(params)
         _safe_call(_LIB.LGBM_BoosterResetParameter(
             self.handle,
...@@ -1040,7 +1040,7 @@ class Booster(object):
         result: str
             Evaluation result list.
         """
-        return self.__inner_eval("training", 0, feval)
+        return self.__inner_eval(self.__train_data_name, 0, feval)

     def eval_valid(self, feval=None):
         """Evaluate for validation data
...@@ -1129,7 +1129,7 @@ class Booster(object):
         if tmp_out_len.value != self.__num_inner_eval:
             raise ValueError("incorrect number of eval results")
         for i in range(self.__num_inner_eval):
-            ret.append((data_name, self.__name_inner_eval[i], result[i]))
+            ret.append((data_name, self.__name_inner_eval[i], result[i], self.__higher_better_inner_eval[i]))
         if feval is not None:
             if data_idx == 0:
                 cur_data = self.train_set
...@@ -1137,11 +1137,11 @@ class Booster(object):
                 cur_data = self.valid_sets[data_idx - 1]
             feval_ret = feval(self.__inner_predict(data_idx), cur_data)
             if isinstance(feval_ret, list):
-                for eval_name, val in feval_ret:
-                    ret.append((data_name, eval_name, val))
+                for eval_name, val, is_higher_better in feval_ret:
+                    ret.append((data_name, eval_name, val, is_higher_better))
             else:
-                eval_name, val = feval_ret
-                ret.append((data_name, eval_name, val))
+                eval_name, val, is_higher_better = feval_ret
+                ret.append((data_name, eval_name, val, is_higher_better))
         return ret

     def __inner_predict(self, data_idx):
...@@ -1197,3 +1197,10 @@ class Booster(object):
         self.__name_inner_eval = []
         for i in range(self.__num_inner_eval):
             self.__name_inner_eval.append(string_buffers[i].value.decode())
+        self.__higher_better_inner_eval = []
+        higher_better_metric = ['auc', 'ndcg']
+        for name in self.__name_inner_eval:
+            if any(name.startswith(x) for x in higher_better_metric):
+                self.__higher_better_inner_eval.append(True)
+            else:
+                self.__higher_better_inner_eval.append(False)
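# With the change above, a custom `feval` must return a 3-tuple
# (eval_name, eval_result, is_higher_better), or a list of such tuples.
# A minimal sketch (assumes `preds` is a numpy array of raw predictions and
# that the dataset object exposes get_label(), as LightGBM's Dataset does):
#
#     def mae(preds, train_data):
#         labels = train_data.get_label()
#         return 'mae', float(abs(preds - labels).mean()), False  # lower is better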
from __future__ import absolute_import

import collections
class EarlyStopException(Exception):
"""Exception of early stopping.
Parameters
----------
best_iteration : int
        The best iteration at which training stopped.
"""
def __init__(self, best_iteration):
super(EarlyStopException, self).__init__()
self.best_iteration = best_iteration
# Callback environment used by callbacks
CallbackEnv = collections.namedtuple(
"LightGBMCallbackEnv",
["model",
"cvfolds",
"iteration",
"begin_iteration",
"end_iteration",
"evaluation_result_list"])
def _format_eval_result(value, show_stdv=True):
"""format metric string"""
if len(value) == 4:
return '%s_%s:%g' % (value[0], value[1], value[2])
elif len(value) == 5:
if show_stdv:
return '%s_%s:%g+%g' % (value[0], value[1], value[2], value[4])
else:
return '%s_%s:%g' % (value[0], value[1], value[2])
else:
raise ValueError("wrong metric value")
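# For example, a 4-tuple comes from Booster evaluation, while a 5-tuple is
# assumed to carry a cross-fold standard deviation at index 4:
#
#     _format_eval_result(('eval', 'auc', 0.75, True))        # -> 'eval_auc:0.75'
#     _format_eval_result(('eval', 'auc', 0.75, True, 0.01))  # -> 'eval_auc:0.75+0.01'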
def print_evaluation(period=1, show_stdv=True):
    """Create a callback that prints the evaluation results.

    Parameters
    ----------
    period : int
        The period to log the evaluation results
    show_stdv : bool, optional
        Whether to show stdv if provided

    Returns
    -------
    callback : function
        A callback that prints the evaluation results every `period` iterations.
    """
    def callback(env):
        """internal function"""
        if not env.evaluation_result_list or period < 1:
            return
        if env.iteration % period == 0 or env.iteration + 1 == env.end_iteration:
            result = '\t'.join([_format_eval_result(x, show_stdv) for x in env.evaluation_result_list])
            print('[%d]\t%s' % (env.iteration, result))
    return callback
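# Usage sketch (data and parameter names are illustrative; see train() in the
# engine module below):
#
#     bst = train(params, (X, y), valid_datas=[(X_val, y_val)],
#                 valid_names=['eval'],
#                 callbacks=[print_evaluation(period=10)])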
def record_evaluation(eval_result):
"""Create a call back that records the evaluation history into eval_result.
Parameters
----------
eval_result : dict
A dictionary to store the evaluation results.
Returns
-------
callback : function
The requested callback function.
"""
if not isinstance(eval_result, dict):
raise TypeError('eval_result has to be a dictionary')
eval_result.clear()
def init(env):
"""internal function"""
        for data_name, eval_name, _, _ in env.evaluation_result_list:
if data_name not in eval_result:
eval_result[data_name] = {}
if eval_name not in eval_result[data_name]:
eval_result[data_name][eval_name] = []
def callback(env):
"""internal function"""
if len(eval_result) == 0:
init(env)
        for data_name, eval_name, result, _ in env.evaluation_result_list:
eval_result[data_name][eval_name].append(result)
return callback
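# Usage sketch: the dict is filled in place during training, keyed first by
# data name and then by metric name (names below are illustrative):
#
#     history = {}
#     bst = train(params, (X, y), valid_datas=[(X_val, y_val)],
#                 valid_names=['eval'],
#                 callbacks=[record_evaluation(history)])
#     # history -> {'eval': {'binary_logloss': [0.48, 0.35, ...]}}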
def reset_learning_rate(learning_rates):
"""Reset learning rate after iteration 1
NOTE: the initial learning rate will still take in-effect on first iteration.
Parameters
----------
learning_rates: list or function
List of learning rate for each boosting round
or a customized function that calculates learning_rate in terms of
current number of round and the total number of boosting round (e.g. yields
learning rate decay)
- list l: learning_rate = l[current_round]
- function f: learning_rate = f(current_round, total_boost_round)
Returns
-------
callback : function
The requested callback function.
"""
def callback(env):
"""internal function"""
booster = env.model
i = env.iteration
if isinstance(learning_rates, list):
if len(learning_rates) != env.end_iteration:
raise ValueError("Length of list 'learning_rates' has to equal 'num_boost_round'.")
            booster.reset_parameter({'learning_rate': learning_rates[i]})
else:
            booster.reset_parameter({'learning_rate': learning_rates(i, env.end_iteration)})
callback.before_iteration = True
return callback
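# Usage sketch for both accepted forms (values are illustrative):
#
#     callbacks=[reset_learning_rate([0.1] * 50 + [0.05] * 50)]        # list: one rate per round
#     callbacks=[reset_learning_rate(lambda i, n: 0.1 * (0.99 ** i))]  # function of (round, total)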
def early_stop(stopping_rounds, verbose=True):
    """Create a callback that activates early stopping.

    Activates early stopping.
    Requires at least one validation data and one metric.
    If there's more than one, will check all of them.

    Parameters
    ----------
    stopping_rounds : int
        Stop training if no metric improves within this many rounds.
    verbose : optional, bool
        Whether to print a message when early stopping occurs.

    Returns
    -------
    callback : function
        The requested callback function.
    """
    # mutable closure state, filled by init() on the first call; plain
    # rebinding inside init() would only create locals there (Python 2
    # has no `nonlocal`), so lists are mutated in place instead
    best_score = []
    best_iter = []
    best_msg = []
    factor_to_bigger_better = []
    def init(env):
        """internal function"""
        if len(env.evaluation_result_list) == 0:
            raise ValueError('For early stopping you need at least one set in evals.')
        if verbose:
            msg = "Will train until the eval metrics don't improve for {} rounds.\n"
            print(msg.format(stopping_rounds))
        for eval_ret in env.evaluation_result_list:
            best_score.append(float('-inf'))
            best_iter.append(0)
            best_msg.append("")
            # eval_ret[3] is the is_higher_better flag set in Booster.__inner_eval
            factor_to_bigger_better.append(1.0 if eval_ret[3] else -1.0)
    def callback(env):
        """internal function"""
        if not best_score:
            init(env)
        for i in range(len(env.evaluation_result_list)):
            score = env.evaluation_result_list[i][2] * factor_to_bigger_better[i]
            if score > best_score[i]:
                best_score[i] = score
                best_iter[i] = env.iteration
                if verbose:
                    best_msg[i] = '[%d]\t%s' % (env.iteration,
                        '\t'.join([_format_eval_result(x) for x in env.evaluation_result_list]))
            elif env.iteration - best_iter[i] >= stopping_rounds:
                if env.model is not None:
                    env.model.set_attr(best_iteration=str(best_iter[i]))
                if verbose:
                    print('early stopping, best message is:\n{}'.format(best_msg[i]))
                raise EarlyStopException(best_iter[i])
    return callback
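# Note on the sign flip above: multiplying each score by
# factor_to_bigger_better turns every metric into a maximization problem, so
# the single `score > best_score[i]` test handles auc (higher is better) and
# logloss (lower is better, tracked as its negation) alike. Usage sketch:
#
#     callbacks=[early_stop(stopping_rounds=50, verbose=True)]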
"""Training Library containing training routines of LightGBM."""
from __future__ import absolute_import
import collections
import numpy as np
from .basic import LightGBMError, Predictor, Dataset, Booster, is_str
from . import callback
def _construct_dataset(x, y, reference=None,
params=None, other_fields=None, predictor=None):
if 'max_bin' in params:
max_bin = int(params['max_bin'])
else:
max_bin = 255
weight = None
group = None
init_score = None
if other_fields is not None:
        if not isinstance(other_fields, dict):
            raise TypeError("other_fields should be a dict")
        weight = other_fields.get('weight')
        group = other_fields.get('group')
        init_score = other_fields.get('init_score')
if reference is None:
ret = Dataset(x, y, max_bin=max_bin,
weight=weight, group=group, predictor=predictor, params=params)
else:
ret = reference.create_valid(x, y, weight, group, params=params)
if init_score is not None:
ret.set_init_score(init_score)
return ret
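# Sketch of the expected `other_fields` layout (all keys are optional; the
# arrays here are hypothetical):
#
#     train_fields = {'weight': weights,         # per-row weights, shape (n,)
#                     'group': group_sizes,      # query group sizes for ranking
#                     'init_score': init_score}  # initial scores, shape (n,)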
def train(params, train_data, num_boost_round=100,
valid_datas=None, valid_names=None,
fobj=None, feval=None, init_model=None,
train_fields=None, valid_fields=None,
early_stopping_rounds=None, out_eval_result=None,
verbose_eval=True, learning_rates=None, callbacks=None):
"""Train with given parameters.
Parameters
----------
    params : dict
        Parameters for training.
    train_data : pair, (X, y)
        Data to be trained on.
    num_boost_round: int
        Number of boosting iterations.
    valid_datas: list of pairs (valid_X, valid_y)
        List of data to be evaluated during training
    valid_names: list of string
        Names of valid_datas
    fobj : function
        Customized objective function.
    feval : function
        Customized evaluation function.
        Note: should return (eval_name, eval_result, is_higher_better) or a list of such tuples
    init_model : file name of lightgbm model or 'Booster' instance
        model used for continued training
    train_fields : dict
        other data fields of the training data, e.g. train_fields['weight'] is the weight data
        supported fields: weight, group, init_score
    valid_fields : dict
        other data fields of the validation data, e.g. valid_fields[0]['weight'] is the weight data of the first validation data
        supported fields: weight, group, init_score
    early_stopping_rounds: int
        Activates early stopping.
        Requires at least one validation data and one metric.
        If there's more than one, will check all of them.
        Returns the model with (best_iter + early_stopping_rounds) iterations
        If early stopping occurs, the model will have a 'best_iteration' attribute
    out_eval_result: dict or None
        This dictionary is used to store all evaluation results of all the items in valid_datas.
        Example: with valid_datas containing [dtest, dtrain], valid_names containing ['eval', 'train'] and
        params containing {'metric': 'logloss'}:
        Returns: {'train': {'logloss': ['0.48253', '0.35953', ...]},
                  'eval': {'logloss': ['0.480385', '0.357756', ...]}}
        Passing None disables this feature
    verbose_eval : bool or int
        Requires at least one item in valid_datas.
        If `verbose_eval` is True then the evaluation metric on the validation set is
        printed at each boosting stage.
        If `verbose_eval` is an integer then the evaluation metric on the validation set
        is printed at every given `verbose_eval` boosting stage. The last boosting stage
        / the boosting stage found by using `early_stopping_rounds` is also printed.
        Example: with verbose_eval=4 and at least one item in valid_datas, an evaluation metric
        is printed every 4 boosting stages, instead of every boosting stage.
learning_rates: list or function
List of learning rate for each boosting round
or a customized function that calculates learning_rate in terms of
current number of round and the total number of boosting round (e.g. yields
learning rate decay)
- list l: learning_rate = l[current_round]
- function f: learning_rate = f(current_round, total_boost_round)
callbacks : list of callback functions
List of callback functions that are applied at end of each iteration.
Returns
-------
booster : a trained booster model
"""
"""create predictor first"""
if is_str(init_model):
predictor = Predictor(model_file=init_model)
elif isinstance(init_model, Booster):
        predictor = init_model.to_predictor()
elif isinstance(init_model, Predictor):
predictor = init_model
else:
predictor = None
"""create dataset"""
train_set = _construct_dataset(train_data[0], train_data[1], None, params, train_fields, predictor, silent)
is_valid_contain_train = False
train_data_name = "training"
valid_sets = []
name_valid_sets = []
if valid_datas is not None:
for i in range(len(valid_datas)):
other_fields = None if valid_fields is None else valid_fields[i]
"""reduce cost for prediction training data"""
if valid_datas[i] is train_data:
is_valid_contain_train = True
train_data_name = valid_names[i]
continue
            valid_set = _construct_dataset(
                valid_datas[i][0],
                valid_datas[i][1],
                train_set,
                params,
                other_fields,
                predictor)
valid_sets.append(valid_set)
name_valid_sets.append(valid_names[i])
"""process callbacks"""
callbacks = [] if callbacks is None else callbacks
# Most of legacy advanced options becomes callbacks
if isinstance(verbose_eval, bool) and verbose_eval:
callbacks.append(callback.print_evaluation())
    elif isinstance(verbose_eval, int):
        callbacks.append(callback.print_evaluation(verbose_eval))
if early_stopping_rounds is not None:
callbacks.append(callback.early_stop(early_stopping_rounds,
verbose=bool(verbose_eval)))
if learning_rates is not None:
callbacks.append(callback.reset_learning_rate(learning_rates))
    if out_eval_result is not None:
        callbacks.append(callback.record_evaluation(out_eval_result))
callbacks_before_iter = [
cb for cb in callbacks if cb.__dict__.get('before_iteration', False)]
callbacks_after_iter = [
cb for cb in callbacks if not cb.__dict__.get('before_iteration', False)]
"""construct booster"""
booster = Booster(params=params, train_set=train_set, silent=silent)
if is_valid_contain_train:
booster.set_train_data_name(train_data_name)
for i in range(len(valid_sets)):
booster.add_valid(valid_sets[i], name_valid_sets[i])
"""start training"""
for i in range(num_boost_round):
for cb in callbacks_before_iter:
            cb(callback.CallbackEnv(model=booster,
cvfolds=None,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=None))
booster.update(fobj=fobj)
evaluation_result_list = []
# check evaluation result.
if len(valid_sets) != 0:
if is_valid_contain_train:
evaluation_result_list.extend(booster.eval_train(feval))
evaluation_result_list.extend(booster.eval_valid(feval))
try:
for cb in callbacks_after_iter:
                cb(callback.CallbackEnv(model=booster,
cvfolds=None,
iteration=i,
begin_iteration=0,
end_iteration=num_boost_round,
evaluation_result_list=evaluation_result_list))
        except callback.EarlyStopException:
break
if booster.attr('best_iteration') is not None:
booster.best_iteration = int(booster.attr('best_iteration'))
else:
booster.best_iteration = num_boost_round - 1
    return booster
\ No newline at end of file
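# End-to-end usage sketch (assumes the package is importable as `lightgbm`
# and uses the (X, y) pair-based API of this commit; data is synthetic):
#
#     import numpy as np
#     import lightgbm as lgb
#
#     X, y = np.random.rand(500, 10), np.random.randint(0, 2, 500)
#     X_val, y_val = np.random.rand(100, 10), np.random.randint(0, 2, 100)
#     params = {'objective': 'binary', 'metric': 'binary_logloss'}
#     history = {}
#     bst = lgb.train(params, (X, y), num_boost_round=100,
#                     valid_datas=[(X_val, y_val)], valid_names=['eval'],
#                     early_stopping_rounds=10, out_eval_result=history)
#     print(bst.best_iteration)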