Commit a4a0235d authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

use json instead of repr/eval for pandas_categorical (#247)

* use json instead of repr/eval for pandas_categorical

* fix json dumps with numpy data

* add more test cases
parent 9c5dbdde
...@@ -11,8 +11,9 @@ from tempfile import NamedTemporaryFile ...@@ -11,8 +11,9 @@ from tempfile import NamedTemporaryFile
import numpy as np import numpy as np
import scipy.sparse import scipy.sparse
from .compat import (DataFrame, Series, integer_types, json, numeric_types, from .compat import (DataFrame, Series, integer_types, json,
range_, string_type) json_default_with_numpy, numeric_types, range_,
string_type)
from .libpath import find_lib_path from .libpath import find_lib_path
...@@ -271,6 +272,19 @@ def _label_from_pandas(label): ...@@ -271,6 +272,19 @@ def _label_from_pandas(label):
return label return label
def _save_pandas_categorical(file_name, pandas_categorical):
    """Append the pandas categorical info to ``file_name`` as a JSON line.

    The data is written on a new final line of the form
    ``pandas_categorical:<json>``. numpy values are converted through
    ``json_default_with_numpy``.

    Parameters
    ----------
    file_name : str
        Path of the file to append to (opened in ``'a'`` mode).
    pandas_categorical : object
        Value to serialize; ``None`` is written as JSON ``null``.
    """
    with open(file_name, 'a') as model_file:
        serialized = json.dumps(pandas_categorical,
                                default=json_default_with_numpy)
        # The leading newline puts the marker on its own, last line.
        model_file.write('\npandas_categorical:' + serialized)
def _load_pandas_categorical(file_name):
with open(file_name, 'r') as f:
last_line = f.readlines()[-1]
if last_line.startswith('pandas_categorical:'):
return json.loads(last_line[len('pandas_categorical:'):])
return None
class _InnerPredictor(object): class _InnerPredictor(object):
""" """
A _InnerPredictor of LightGBM. A _InnerPredictor of LightGBM.
...@@ -302,12 +316,7 @@ class _InnerPredictor(object): ...@@ -302,12 +316,7 @@ class _InnerPredictor(object):
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
self.num_class = out_num_class.value self.num_class = out_num_class.value
self.num_total_iteration = out_num_iterations.value self.num_total_iteration = out_num_iterations.value
with open(model_file, 'r') as f: self.pandas_categorical = _load_pandas_categorical(model_file)
last_line = f.readlines()[-1]
if last_line.startswith('pandas_categorical:'):
self.pandas_categorical = eval(last_line[len('pandas_categorical:'):])
else:
self.pandas_categorical = None
elif booster_handle is not None: elif booster_handle is not None:
self.__is_manage_handle = False self.__is_manage_handle = False
self.handle = booster_handle self.handle = booster_handle
...@@ -1207,12 +1216,7 @@ class Booster(object): ...@@ -1207,12 +1216,7 @@ class Booster(object):
self.handle, self.handle,
ctypes.byref(out_num_class))) ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value self.__num_class = out_num_class.value
with open(model_file, 'r') as f: self.pandas_categorical = _load_pandas_categorical(model_file)
last_line = f.readlines()[-1]
if last_line.startswith('pandas_categorical:'):
self.pandas_categorical = eval(last_line[len('pandas_categorical:'):])
else:
self.pandas_categorical = None
elif 'model_str' in params: elif 'model_str' in params:
self.__load_model_from_string(params['model_str']) self.__load_model_from_string(params['model_str'])
else: else:
...@@ -1468,8 +1472,7 @@ class Booster(object): ...@@ -1468,8 +1472,7 @@ class Booster(object):
self.handle, self.handle,
ctypes.c_int(num_iteration), ctypes.c_int(num_iteration),
c_str(filename))) c_str(filename)))
with open(filename, 'a') as f: _save_pandas_categorical(filename, self.pandas_categorical)
f.write('\npandas_categorical:' + repr(self.pandas_categorical))
def __load_model_from_string(self, model_str): def __load_model_from_string(self, model_str):
"""[Private] Load model from string""" """[Private] Load model from string"""
......
...@@ -6,6 +6,8 @@ from __future__ import absolute_import ...@@ -6,6 +6,8 @@ from __future__ import absolute_import
import inspect import inspect
import sys import sys
import numpy as np
is_py3 = (sys.version_info[0] == 3) is_py3 = (sys.version_info[0] == 3)
"""compatibility between python2 and python3""" """compatibility between python2 and python3"""
...@@ -36,6 +38,16 @@ except (ImportError, SyntaxError): ...@@ -36,6 +38,16 @@ except (ImportError, SyntaxError):
# because of u'...' Unicode literals. # because of u'...' Unicode literals.
import json import json
def json_default_with_numpy(obj):
    """Convert numpy values to JSON-serializable Python objects.

    Intended to be passed as the ``default`` argument of ``json.dumps``.

    Parameters
    ----------
    obj : object
        A value the JSON encoder could not serialize natively.

    Returns
    -------
    object
        A Python scalar for numpy integer/floating/bool scalars, or a
        (nested) list for a numpy array.

    Raises
    ------
    TypeError
        If ``obj`` is not a supported numpy type. Returning ``obj``
        unchanged would make the encoder report a misleading
        "Circular reference detected" error instead.
    """
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError('Object of type %s is not JSON serializable'
                    % type(obj).__name__)
"""pandas""" """pandas"""
try: try:
from pandas import Series, DataFrame from pandas import Series, DataFrame
......
...@@ -146,15 +146,18 @@ class TestEngine(unittest.TestCase): ...@@ -146,15 +146,18 @@ class TestEngine(unittest.TestCase):
@unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas not installed') @unittest.skipIf(not IS_PANDAS_INSTALLED, 'pandas not installed')
def test_pandas_categorical(self): def test_pandas_categorical(self):
X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), X = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'c', 'd'] * 75), # str
"B": np.random.permutation([1, 2, 3] * 100)}) "B": np.random.permutation([1, 2, 3] * 100), # int
X["A"] = X["A"].astype('category') "C": np.random.permutation([0.1, 0.2, -0.1, -0.1, 0.2] * 60), # float
X["B"] = X["B"].astype('category') "D": np.random.permutation([True, False] * 150)}) # bool
y = np.random.permutation([0, 1] * 150) y = np.random.permutation([0, 1] * 150)
X_test = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'e'] * 20), X_test = pd.DataFrame({"A": np.random.permutation(['a', 'b', 'e'] * 20),
"B": np.random.permutation([1, 3] * 30)}) "B": np.random.permutation([1, 3] * 30),
X_test["A"] = X_test["A"].astype('category') "C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15),
X_test["B"] = X_test["B"].astype('category') "D": np.random.permutation([True, False] * 30)})
for col in ["A", "B", "C", "D"]:
X[col] = X[col].astype('category')
X_test[col] = X_test[col].astype('category')
params = { params = {
'objective': 'binary', 'objective': 'binary',
'metric': 'binary_logloss', 'metric': 'binary_logloss',
...@@ -173,7 +176,7 @@ class TestEngine(unittest.TestCase): ...@@ -173,7 +176,7 @@ class TestEngine(unittest.TestCase):
pred2 = list(gbm2.predict(X_test)) pred2 = list(gbm2.predict(X_test))
lgb_train = lgb.Dataset(X, y) lgb_train = lgb.Dataset(X, y)
gbm3 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False, gbm3 = lgb.train(params, lgb_train, num_boost_round=10, verbose_eval=False,
categorical_feature=['A', 'B']) categorical_feature=['A', 'B', 'C', 'D'])
pred3 = list(gbm3.predict(X_test)) pred3 = list(gbm3.predict(X_test))
lgb_train = lgb.Dataset(X, y) lgb_train = lgb.Dataset(X, y)
gbm3.save_model('categorical.model') gbm3.save_model('categorical.model')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment