Unverified Commit 5b5b9823 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] make possibility to create Booster from string official (#2098)

parent 0a4a7a86
......@@ -397,7 +397,7 @@ class _InnerPredictor(object):
self.num_total_iteration = out_num_iterations.value
self.pandas_categorical = None
else:
raise TypeError('Need Model file or Booster handle to create a predictor')
raise TypeError('Need model_file or booster_handle to create a predictor')
pred_parameter = {} if pred_parameter is None else pred_parameter
self.pred_parameter = param_dict_to_str(pred_parameter)
......@@ -1578,7 +1578,7 @@ class Dataset(object):
class Booster(object):
"""Booster in LightGBM."""
def __init__(self, params=None, train_set=None, model_file=None, silent=False):
def __init__(self, params=None, train_set=None, model_file=None, model_str=None, silent=False):
"""Initialize the Booster.
Parameters
......@@ -1589,6 +1589,8 @@ class Booster(object):
Training dataset.
model_file : string or None, optional (default=None)
Path to the model file.
model_str : string or None, optional (default=None)
Model will be loaded from this string.
silent : bool, optional (default=False)
Whether to print messages during construction.
"""
......@@ -1666,10 +1668,11 @@ class Booster(object):
ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
elif 'model_str' in params:
self.model_from_string(params['model_str'], False)
elif model_str is not None:
self.model_from_string(model_str, not silent)
else:
raise TypeError('Need at least one training dataset or model file to create booster instance')
raise TypeError('Need at least one training dataset or model file or model string '
'to create Booster instance')
self.params = params
def __del__(self):
......@@ -1689,7 +1692,7 @@ class Booster(object):
def __deepcopy__(self, _):
model_str = self.model_to_string(num_iteration=-1)
booster = Booster({'model_str': model_str})
booster = Booster(model_str=model_str)
return booster
def __getstate__(self):
......
......@@ -583,7 +583,7 @@ class TestEngine(unittest.TestCase):
model_str = gbm4.model_to_string()
gbm4.model_from_string(model_str, False)
pred5 = gbm4.predict(X_test)
gbm5 = lgb.Booster({'model_str': model_str})
gbm5 = lgb.Booster(model_str=model_str)
pred6 = gbm5.predict(X_test)
np.testing.assert_almost_equal(pred0, pred1)
np.testing.assert_almost_equal(pred0, pred2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment