"tests/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "23403a7c2c4d38961506968291f13c3f344ccb6f"
Commit f7d190aa authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

[python] add keep_training_booster (#673)

* add keep_training_booster

* use model string

* reset handle; free dataset
parent ac73638f
......@@ -1245,7 +1245,7 @@ class Booster(object):
self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(model_file)
elif 'model_str' in params:
self.__load_model_from_string(params['model_str'])
self._load_model_from_string(params['model_str'])
else:
raise TypeError('Need at least one training dataset or model file to create booster instance')
......@@ -1257,7 +1257,7 @@ class Booster(object):
return self.__deepcopy__(None)
def __deepcopy__(self, _):
    """Deep-copy this Booster by round-tripping it through its model string.

    Serializing to a model string and loading it into a fresh Booster yields
    an independent copy. The pandas categorical metadata is copied over
    separately, since it is kept outside the model string.

    Parameters
    ----------
    _ : object
        The memo dict supplied by copy.deepcopy; unused here.

    Returns
    -------
    Booster
        A new, independent Booster equivalent to this one.
    """
    # The scraped diff carried both the old (`__save_model_to_string`) and
    # new (`_save_model_to_string`) call as consecutive statements; only the
    # post-commit renamed call is kept — the first assignment was dead code.
    model_str = self._save_model_to_string()
    booster = Booster({'model_str': model_str})
    booster.pandas_categorical = self.pandas_categorical
    return booster
......@@ -1268,7 +1268,7 @@ class Booster(object):
this.pop('train_set', None)
this.pop('valid_sets', None)
if handle is not None:
this["handle"] = self.__save_model_to_string()
this["handle"] = self._save_model_to_string()
return this
def __setstate__(self, state):
......@@ -1286,6 +1286,7 @@ class Booster(object):
def free_dataset(self):
    """Release references to the training and validation data.

    Removes the booster's stored train/valid dataset attributes, if present,
    and resets the internal dataset counter to zero so the booster no longer
    keeps the datasets alive.
    """
    for attr in ('train_set', 'valid_sets'):
        self.__dict__.pop(attr, None)
    self.__num_dataset = 0
def set_train_data_name(self, name):
    """Store *name* as the label for the training dataset.

    Parameters
    ----------
    name : string
        Name to record for the training data on this instance.
    """
    self.__train_data_name = name
......@@ -1505,7 +1506,7 @@ class Booster(object):
c_str(filename)))
_save_pandas_categorical(filename, self.pandas_categorical)
def __load_model_from_string(self, model_str):
def _load_model_from_string(self, model_str):
"""[Private] Load model from string"""
out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterLoadModelFromString(
......@@ -1518,7 +1519,7 @@ class Booster(object):
ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value
def __save_model_to_string(self, num_iteration=-1):
def _save_model_to_string(self, num_iteration=-1):
"""[Private] Save model to string"""
if num_iteration <= 0:
num_iteration = self.best_iteration
......
......@@ -19,7 +19,8 @@ def train(params, train_set, num_boost_round=100,
fobj=None, feval=None, init_model=None,
feature_name='auto', categorical_feature='auto',
early_stopping_rounds=None, evals_result=None,
verbose_eval=True, learning_rates=None, callbacks=None):
verbose_eval=True, learning_rates=None,
keep_training_booster=False, callbacks=None):
"""
Train with given parameters.
......@@ -80,6 +81,10 @@ def train(params, train_set, num_boost_round=100,
in terms of current number of round (e.g. yields learning rate decay)
- list l: learning_rate = l[current_round]
- function f: learning_rate = f(current_round)
keep_training_booster : boolean
Whether the returned booster will be used to continue training.
If False, it will be converted into an _InnerPredictor before being returned.
You can still use an _InnerPredictor as init_model to continue training later.
callbacks : list of callback functions
List of callback functions that are applied at each iteration.
See Callbacks in Python-API.md for more information.
......@@ -200,6 +205,9 @@ def train(params, train_set, num_boost_round=100,
booster.best_score = collections.defaultdict(dict)
for dataset_name, eval_name, score, _ in evaluation_result_list:
booster.best_score[dataset_name][eval_name] = score
if not keep_training_booster:
booster._load_model_from_string(booster._save_model_to_string())
booster.free_dataset()
return booster
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment