Unverified Commit 5b5b9823 authored by Nikita Titov's avatar Nikita Titov Committed by GitHub
Browse files

[python] make possibility to create Booster from string official (#2098)

parent 0a4a7a86
...@@ -397,7 +397,7 @@ class _InnerPredictor(object): ...@@ -397,7 +397,7 @@ class _InnerPredictor(object):
self.num_total_iteration = out_num_iterations.value self.num_total_iteration = out_num_iterations.value
self.pandas_categorical = None self.pandas_categorical = None
else: else:
raise TypeError('Need Model file or Booster handle to create a predictor') raise TypeError('Need model_file or booster_handle to create a predictor')
pred_parameter = {} if pred_parameter is None else pred_parameter pred_parameter = {} if pred_parameter is None else pred_parameter
self.pred_parameter = param_dict_to_str(pred_parameter) self.pred_parameter = param_dict_to_str(pred_parameter)
...@@ -1578,7 +1578,7 @@ class Dataset(object): ...@@ -1578,7 +1578,7 @@ class Dataset(object):
class Booster(object): class Booster(object):
"""Booster in LightGBM.""" """Booster in LightGBM."""
def __init__(self, params=None, train_set=None, model_file=None, silent=False): def __init__(self, params=None, train_set=None, model_file=None, model_str=None, silent=False):
"""Initialize the Booster. """Initialize the Booster.
Parameters Parameters
...@@ -1589,6 +1589,8 @@ class Booster(object): ...@@ -1589,6 +1589,8 @@ class Booster(object):
Training dataset. Training dataset.
model_file : string or None, optional (default=None) model_file : string or None, optional (default=None)
Path to the model file. Path to the model file.
model_str : string or None, optional (default=None)
Model will be loaded from this string.
silent : bool, optional (default=False) silent : bool, optional (default=False)
Whether to print messages during construction. Whether to print messages during construction.
""" """
...@@ -1666,10 +1668,11 @@ class Booster(object): ...@@ -1666,10 +1668,11 @@ class Booster(object):
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(file_name=model_file) self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
elif 'model_str' in params: elif model_str is not None:
self.model_from_string(params['model_str'], False) self.model_from_string(model_str, not silent)
else: else:
raise TypeError('Need at least one training dataset or model file to create booster instance') raise TypeError('Need at least one training dataset or model file or model string '
'to create Booster instance')
self.params = params self.params = params
def __del__(self): def __del__(self):
...@@ -1689,7 +1692,7 @@ class Booster(object): ...@@ -1689,7 +1692,7 @@ class Booster(object):
def __deepcopy__(self, _): def __deepcopy__(self, _):
model_str = self.model_to_string(num_iteration=-1) model_str = self.model_to_string(num_iteration=-1)
booster = Booster({'model_str': model_str}) booster = Booster(model_str=model_str)
return booster return booster
def __getstate__(self): def __getstate__(self):
......
...@@ -583,7 +583,7 @@ class TestEngine(unittest.TestCase): ...@@ -583,7 +583,7 @@ class TestEngine(unittest.TestCase):
model_str = gbm4.model_to_string() model_str = gbm4.model_to_string()
gbm4.model_from_string(model_str, False) gbm4.model_from_string(model_str, False)
pred5 = gbm4.predict(X_test) pred5 = gbm4.predict(X_test)
gbm5 = lgb.Booster({'model_str': model_str}) gbm5 = lgb.Booster(model_str=model_str)
pred6 = gbm5.predict(X_test) pred6 = gbm5.predict(X_test)
np.testing.assert_almost_equal(pred0, pred1) np.testing.assert_almost_equal(pred0, pred1)
np.testing.assert_almost_equal(pred0, pred2) np.testing.assert_almost_equal(pred0, pred2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment