Commit f7d190aa authored by wxchan, committed by Guolin Ke
Browse files

[python] add keep_training_booster (#673)

* add keep_training_booster

* use model string

* reset handle; free dataset
parent ac73638f
...@@ -1245,7 +1245,7 @@ class Booster(object): ...@@ -1245,7 +1245,7 @@ class Booster(object):
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(model_file) self.pandas_categorical = _load_pandas_categorical(model_file)
elif 'model_str' in params: elif 'model_str' in params:
self.__load_model_from_string(params['model_str']) self._load_model_from_string(params['model_str'])
else: else:
raise TypeError('Need at least one training dataset or model file to create booster instance') raise TypeError('Need at least one training dataset or model file to create booster instance')
...@@ -1257,7 +1257,7 @@ class Booster(object): ...@@ -1257,7 +1257,7 @@ class Booster(object):
return self.__deepcopy__(None) return self.__deepcopy__(None)
def __deepcopy__(self, _): def __deepcopy__(self, _):
model_str = self.__save_model_to_string() model_str = self._save_model_to_string()
booster = Booster({'model_str': model_str}) booster = Booster({'model_str': model_str})
booster.pandas_categorical = self.pandas_categorical booster.pandas_categorical = self.pandas_categorical
return booster return booster
...@@ -1268,7 +1268,7 @@ class Booster(object): ...@@ -1268,7 +1268,7 @@ class Booster(object):
this.pop('train_set', None) this.pop('train_set', None)
this.pop('valid_sets', None) this.pop('valid_sets', None)
if handle is not None: if handle is not None:
this["handle"] = self.__save_model_to_string() this["handle"] = self._save_model_to_string()
return this return this
def __setstate__(self, state): def __setstate__(self, state):
...@@ -1286,6 +1286,7 @@ class Booster(object): ...@@ -1286,6 +1286,7 @@ class Booster(object):
def free_dataset(self): def free_dataset(self):
self.__dict__.pop('train_set', None) self.__dict__.pop('train_set', None)
self.__dict__.pop('valid_sets', None) self.__dict__.pop('valid_sets', None)
self.__num_dataset = 0
def set_train_data_name(self, name): def set_train_data_name(self, name):
self.__train_data_name = name self.__train_data_name = name
...@@ -1505,7 +1506,7 @@ class Booster(object): ...@@ -1505,7 +1506,7 @@ class Booster(object):
c_str(filename))) c_str(filename)))
_save_pandas_categorical(filename, self.pandas_categorical) _save_pandas_categorical(filename, self.pandas_categorical)
def __load_model_from_string(self, model_str): def _load_model_from_string(self, model_str):
"""[Private] Load model from string""" """[Private] Load model from string"""
out_num_iterations = ctypes.c_int(0) out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterLoadModelFromString( _safe_call(_LIB.LGBM_BoosterLoadModelFromString(
...@@ -1518,7 +1519,7 @@ class Booster(object): ...@@ -1518,7 +1519,7 @@ class Booster(object):
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
def __save_model_to_string(self, num_iteration=-1): def _save_model_to_string(self, num_iteration=-1):
"""[Private] Save model to string""" """[Private] Save model to string"""
if num_iteration <= 0: if num_iteration <= 0:
num_iteration = self.best_iteration num_iteration = self.best_iteration
......
...@@ -19,7 +19,8 @@ def train(params, train_set, num_boost_round=100, ...@@ -19,7 +19,8 @@ def train(params, train_set, num_boost_round=100,
fobj=None, feval=None, init_model=None, fobj=None, feval=None, init_model=None,
feature_name='auto', categorical_feature='auto', feature_name='auto', categorical_feature='auto',
early_stopping_rounds=None, evals_result=None, early_stopping_rounds=None, evals_result=None,
verbose_eval=True, learning_rates=None, callbacks=None): verbose_eval=True, learning_rates=None,
keep_training_booster=False, callbacks=None):
""" """
Train with given parameters. Train with given parameters.
...@@ -80,6 +81,10 @@ def train(params, train_set, num_boost_round=100, ...@@ -80,6 +81,10 @@ def train(params, train_set, num_boost_round=100,
in terms of current number of round (e.g. yields learning rate decay) in terms of current number of round (e.g. yields learning rate decay)
- list l: learning_rate = l[current_round] - list l: learning_rate = l[current_round]
- function f: learning_rate = f(current_round) - function f: learning_rate = f(current_round)
keep_training_booster : boolean
Whether the returned booster will be used to keep training.
If False, the booster will be converted into _InnerPredictor before returning.
You can still use _InnerPredictor as init_model to continue training in the future.
callbacks : list of callback functions callbacks : list of callback functions
List of callback functions that are applied at each iteration. List of callback functions that are applied at each iteration.
See Callbacks in Python-API.md for more information. See Callbacks in Python-API.md for more information.
...@@ -200,6 +205,9 @@ def train(params, train_set, num_boost_round=100, ...@@ -200,6 +205,9 @@ def train(params, train_set, num_boost_round=100,
booster.best_score = collections.defaultdict(dict) booster.best_score = collections.defaultdict(dict)
for dataset_name, eval_name, score, _ in evaluation_result_list: for dataset_name, eval_name, score, _ in evaluation_result_list:
booster.best_score[dataset_name][eval_name] = score booster.best_score[dataset_name][eval_name] = score
if not keep_training_booster:
booster._load_model_from_string(booster._save_model_to_string())
booster.free_dataset()
return booster return booster
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment