Commit 861de1c1 authored by Nikita Titov, committed by Guolin Ke
Browse files

improved model loading routines (#1979)

parent 40486b6c
......@@ -306,22 +306,29 @@ def _dump_pandas_categorical(pandas_categorical, file_name=None):
def _load_pandas_categorical(file_name=None, model_str=None):
    """Load the pandas categorical mapping stored at the end of a model.

    The mapping is looked up on the trailing line of the model, which is
    expected to have the form ``pandas_categorical:<json>``.

    Parameters
    ----------
    file_name : string or None, optional (default=None)
        Path to a model file. Takes precedence over ``model_str``.
        Only the tail of the file is read.
    model_str : string or None, optional (default=None)
        Model contents as a string. Used when ``file_name`` is None.

    Returns
    -------
    pandas_categorical : object or None
        The decoded JSON payload of the trailing ``pandas_categorical:``
        line, or None when no such line is found or when neither
        ``file_name`` nor ``model_str`` is given.
    """
    pandas_key = 'pandas_categorical:'
    # Start by looking at just enough trailing characters to hold the key,
    # then widen the window until the trailing lines are fully captured.
    offset = -len(pandas_key)
    if file_name is not None:
        max_offset = -os.path.getsize(file_name)
        with open(file_name, 'rb') as f:
            while True:
                # Never seek before the beginning of the file.
                if offset < max_offset:
                    offset = max_offset
                f.seek(offset, os.SEEK_END)
                lines = f.readlines()
                # The first line read may be a fragment cut by the seek
                # position, so require three lines to be sure the last two
                # are complete. Also stop once the whole file has been read;
                # otherwise files with too few lines would loop forever.
                if len(lines) >= 3 or offset == max_offset:
                    break
                offset *= 2
        if not lines:
            # Empty file: nothing to parse.
            return None
        last_line = decode_string(lines[-1]).strip()
        # Tolerate one trailing blank (or otherwise unrelated) last line.
        if not last_line.startswith(pandas_key) and len(lines) >= 2:
            last_line = decode_string(lines[-2]).strip()
    elif model_str is not None:
        # Find the newline that precedes the last line. When there is none,
        # rfind returns -1 and ``idx + 1`` selects the whole string.
        idx = model_str.rfind('\n', 0, offset)
        last_line = model_str[idx + 1:].strip()
    else:
        # Nothing to load from.
        return None
    if last_line.startswith(pandas_key):
        return json.loads(last_line[len(pandas_key):])
    return None
class _InnerPredictor(object):
......
......@@ -349,8 +349,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
std::unordered_map<std::string, std::string> key_vals;
while (p < end) {
auto line_len = Common::GetLine(p);
std::string cur_line(p, line_len);
if (line_len > 0) {
std::string cur_line(p, line_len);
if (!Common::StartsWith(cur_line, "Tree=")) {
auto strs = Common::Split(cur_line.c_str(), '=');
if (strs.size() == 1) {
......@@ -442,8 +442,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
if (!key_vals.count("tree_sizes")) {
while (p < end) {
auto line_len = Common::GetLine(p);
std::string cur_line(p, line_len);
if (line_len > 0) {
std::string cur_line(p, line_len);
if (Common::StartsWith(cur_line, "Tree=")) {
p += line_len;
p = Common::SkipNewLine(p);
......@@ -491,8 +491,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
std::stringstream ss;
while (p < end) {
auto line_len = Common::GetLine(p);
std::string cur_line(p, line_len);
if (line_len > 0) {
std::string cur_line(p, line_len);
if (cur_line == std::string("parameters:")) {
is_inparameter = true;
} else if (cur_line == std::string("end of parameters")) {
......
......@@ -551,9 +551,11 @@ class TestEngine(unittest.TestCase):
"B": np.random.permutation([1, 3] * 30),
"C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15),
"D": np.random.permutation([True, False] * 30)})
cat_cols = []
for col in ["A", "B", "C", "D"]:
X[col] = X[col].astype('category')
X_test[col] = X_test[col].astype('category')
cat_cols.append(X[col].cat.categories.tolist())
params = {
'objective': 'binary',
'metric': 'binary_logloss',
......@@ -588,6 +590,12 @@ class TestEngine(unittest.TestCase):
np.testing.assert_almost_equal(pred0, pred4)
np.testing.assert_almost_equal(pred0, pred5)
np.testing.assert_almost_equal(pred0, pred6)
self.assertListEqual(gbm0.pandas_categorical, cat_cols)
self.assertListEqual(gbm1.pandas_categorical, cat_cols)
self.assertListEqual(gbm2.pandas_categorical, cat_cols)
self.assertListEqual(gbm3.pandas_categorical, cat_cols)
self.assertListEqual(gbm4.pandas_categorical, cat_cols)
self.assertListEqual(gbm5.pandas_categorical, cat_cols)
def test_reference_chain(self):
X = np.random.normal(size=(100, 2))
......
......@@ -215,25 +215,32 @@ class TestSklearn(unittest.TestCase):
"B": np.random.permutation([1, 3] * 30),
"C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15),
"D": np.random.permutation([True, False] * 30)})
cat_cols = []
for col in ["A", "B", "C", "D"]:
X[col] = X[col].astype('category')
X_test[col] = X_test[col].astype('category')
cat_cols.append(X[col].cat.categories.tolist())
gbm0 = lgb.sklearn.LGBMClassifier().fit(X, y)
pred0 = list(gbm0.predict(X_test))
gbm1 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=[0])
pred1 = list(gbm1.predict(X_test))
pred0 = gbm0.predict(X_test)
gbm1 = lgb.sklearn.LGBMClassifier().fit(X, pd.Series(y), categorical_feature=[0])
pred1 = gbm1.predict(X_test)
gbm2 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['A'])
pred2 = list(gbm2.predict(X_test))
pred2 = gbm2.predict(X_test)
gbm3 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['A', 'B', 'C', 'D'])
pred3 = list(gbm3.predict(X_test))
pred3 = gbm3.predict(X_test)
gbm3.booster_.save_model('categorical.model')
gbm4 = lgb.Booster(model_file='categorical.model')
pred4 = list(gbm4.predict(X_test))
pred_prob = list(gbm0.predict_proba(X_test)[:, 1])
pred4 = gbm4.predict(X_test)
pred_prob = gbm0.predict_proba(X_test)[:, 1]
np.testing.assert_almost_equal(pred0, pred1)
np.testing.assert_almost_equal(pred0, pred2)
np.testing.assert_almost_equal(pred0, pred3)
np.testing.assert_almost_equal(pred_prob, pred4)
self.assertListEqual(gbm0.booster_.pandas_categorical, cat_cols)
self.assertListEqual(gbm1.booster_.pandas_categorical, cat_cols)
self.assertListEqual(gbm2.booster_.pandas_categorical, cat_cols)
self.assertListEqual(gbm3.booster_.pandas_categorical, cat_cols)
self.assertListEqual(gbm4.pandas_categorical, cat_cols)
def test_predict(self):
iris = load_iris()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment