Commit af3c4f89 authored by Nikita Titov's avatar Nikita Titov Committed by Qiwei Ye
Browse files

[python] save pandas_categorical to model string and JSON (#1766)

* add pandas_categorical to returned json

* save pandas_categorical to model file

* added regression test

* removed excess conversion to list

* removed deprecated line

* hotfix
parent 9a2bf266
...@@ -295,13 +295,18 @@ def _label_from_pandas(label): ...@@ -295,13 +295,18 @@ def _label_from_pandas(label):
return label return label
def _save_pandas_categorical(file_name, pandas_categorical): def _dump_pandas_categorical(pandas_categorical, file_name=None):
pandas_str = ('\npandas_categorical:'
+ json.dumps(pandas_categorical, default=json_default_with_numpy)
+ '\n')
if file_name is not None:
with open(file_name, 'a') as f: with open(file_name, 'a') as f:
f.write('\npandas_categorical:' f.write(pandas_str)
+ json.dumps(pandas_categorical, default=json_default_with_numpy) + '\n') return pandas_str
def _load_pandas_categorical(file_name): def _load_pandas_categorical(file_name=None, model_str=None):
if file_name is not None:
with open(file_name, 'r') as f: with open(file_name, 'r') as f:
lines = f.readlines() lines = f.readlines()
last_line = lines[-1] last_line = lines[-1]
...@@ -309,6 +314,13 @@ def _load_pandas_categorical(file_name): ...@@ -309,6 +314,13 @@ def _load_pandas_categorical(file_name):
last_line = lines[-2] last_line = lines[-2]
if last_line.startswith('pandas_categorical:'): if last_line.startswith('pandas_categorical:'):
return json.loads(last_line[len('pandas_categorical:'):]) return json.loads(last_line[len('pandas_categorical:'):])
elif model_str is not None:
lines = model_str.split('\n')
last_line = lines[-1]
if last_line.strip() == "":
last_line = lines[-2]
if last_line.startswith('pandas_categorical:'):
return json.loads(last_line[len('pandas_categorical:'):])
return None return None
...@@ -350,7 +362,7 @@ class _InnerPredictor(object): ...@@ -350,7 +362,7 @@ class _InnerPredictor(object):
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
self.num_class = out_num_class.value self.num_class = out_num_class.value
self.num_total_iteration = out_num_iterations.value self.num_total_iteration = out_num_iterations.value
self.pandas_categorical = _load_pandas_categorical(model_file) self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
elif booster_handle is not None: elif booster_handle is not None:
self.__is_manage_handle = False self.__is_manage_handle = False
self.handle = booster_handle self.handle = booster_handle
...@@ -1531,9 +1543,9 @@ class Booster(object): ...@@ -1531,9 +1543,9 @@ class Booster(object):
self.handle, self.handle,
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(model_file) self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
elif 'model_str' in params: elif 'model_str' in params:
self.model_from_string(params['model_str']) self.model_from_string(params['model_str'], False)
else: else:
raise TypeError('Need at least one training dataset or model file to create booster instance') raise TypeError('Need at least one training dataset or model file to create booster instance')
self.params = params self.params = params
...@@ -1556,7 +1568,6 @@ class Booster(object): ...@@ -1556,7 +1568,6 @@ class Booster(object):
def __deepcopy__(self, _): def __deepcopy__(self, _):
model_str = self.model_to_string(num_iteration=-1) model_str = self.model_to_string(num_iteration=-1)
booster = Booster({'model_str': model_str}) booster = Booster({'model_str': model_str})
booster.pandas_categorical = self.pandas_categorical
return booster return booster
def __getstate__(self): def __getstate__(self):
...@@ -1950,7 +1961,7 @@ class Booster(object): ...@@ -1950,7 +1961,7 @@ class Booster(object):
ctypes.c_int(start_iteration), ctypes.c_int(start_iteration),
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
c_str(filename))) c_str(filename)))
_save_pandas_categorical(filename, self.pandas_categorical) _dump_pandas_categorical(self.pandas_categorical, filename)
return self return self
def shuffle_models(self, start_iteration=0, end_iteration=-1): def shuffle_models(self, start_iteration=0, end_iteration=-1):
...@@ -2006,6 +2017,7 @@ class Booster(object): ...@@ -2006,6 +2017,7 @@ class Booster(object):
if verbose: if verbose:
print('Finished loading model, total used %d iterations' % int(out_num_iterations.value)) print('Finished loading model, total used %d iterations' % int(out_num_iterations.value))
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(model_str=model_str)
return self return self
def model_to_string(self, num_iteration=None, start_iteration=0): def model_to_string(self, num_iteration=None, start_iteration=0):
...@@ -2050,7 +2062,9 @@ class Booster(object): ...@@ -2050,7 +2062,9 @@ class Booster(object):
ctypes.c_int64(actual_len), ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
ptr_string_buffer)) ptr_string_buffer))
return string_buffer.value.decode() ret = string_buffer.value.decode()
ret += _dump_pandas_categorical(self.pandas_categorical)
return ret
def dump_model(self, num_iteration=None, start_iteration=0): def dump_model(self, num_iteration=None, start_iteration=0):
"""Dump Booster to JSON format. """Dump Booster to JSON format.
...@@ -2094,7 +2108,10 @@ class Booster(object): ...@@ -2094,7 +2108,10 @@ class Booster(object):
ctypes.c_int64(actual_len), ctypes.c_int64(actual_len),
ctypes.byref(tmp_out_len), ctypes.byref(tmp_out_len),
ptr_string_buffer)) ptr_string_buffer))
return json.loads(string_buffer.value.decode()) ret = json.loads(string_buffer.value.decode())
ret['pandas_categorical'] = json.loads(json.dumps(self.pandas_categorical,
default=json_default_with_numpy))
return ret
def predict(self, data, num_iteration=None, def predict(self, data, num_iteration=None,
raw_score=False, pred_leaf=False, pred_contrib=False, raw_score=False, pred_leaf=False, pred_contrib=False,
......
...@@ -560,26 +560,33 @@ class TestEngine(unittest.TestCase): ...@@ -560,26 +560,33 @@ class TestEngine(unittest.TestCase):
} }
lgb_train = lgb.Dataset(X, y) lgb_train = lgb.Dataset(X, y)
gbm0 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False) gbm0 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False)
pred0 = list(gbm0.predict(X_test)) pred0 = gbm0.predict(X_test)
lgb_train = lgb.Dataset(X, pd.DataFrame(y)) # also test that label can be one-column pd.DataFrame lgb_train = lgb.Dataset(X, pd.DataFrame(y)) # also test that label can be one-column pd.DataFrame
gbm1 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False, gbm1 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False,
categorical_feature=[0]) categorical_feature=[0])
pred1 = list(gbm1.predict(X_test)) pred1 = gbm1.predict(X_test)
lgb_train = lgb.Dataset(X, pd.Series(y)) # also test that label can be pd.Series lgb_train = lgb.Dataset(X, pd.Series(y)) # also test that label can be pd.Series
gbm2 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False, gbm2 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False,
categorical_feature=['A']) categorical_feature=['A'])
pred2 = list(gbm2.predict(X_test)) pred2 = gbm2.predict(X_test)
lgb_train = lgb.Dataset(X, y) lgb_train = lgb.Dataset(X, y)
gbm3 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False, gbm3 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False,
categorical_feature=['A', 'B', 'C', 'D']) categorical_feature=['A', 'B', 'C', 'D'])
pred3 = list(gbm3.predict(X_test)) pred3 = gbm3.predict(X_test)
gbm3.save_model('categorical.model') gbm3.save_model('categorical.model')
gbm4 = lgb.Booster(model_file='categorical.model') gbm4 = lgb.Booster(model_file='categorical.model')
pred4 = list(gbm4.predict(X_test)) pred4 = gbm4.predict(X_test)
model_str = gbm4.model_to_string()
gbm4.model_from_string(model_str, False)
pred5 = gbm4.predict(X_test)
gbm5 = lgb.Booster({'model_str': model_str})
pred6 = gbm5.predict(X_test)
np.testing.assert_almost_equal(pred0, pred1) np.testing.assert_almost_equal(pred0, pred1)
np.testing.assert_almost_equal(pred0, pred2) np.testing.assert_almost_equal(pred0, pred2)
np.testing.assert_almost_equal(pred0, pred3) np.testing.assert_almost_equal(pred0, pred3)
np.testing.assert_almost_equal(pred0, pred4) np.testing.assert_almost_equal(pred0, pred4)
np.testing.assert_almost_equal(pred0, pred5)
np.testing.assert_almost_equal(pred0, pred6)
def test_reference_chain(self): def test_reference_chain(self):
X = np.random.normal(size=(100, 2)) X = np.random.normal(size=(100, 2))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment