Commit 861de1c1 authored by Nikita Titov's avatar Nikita Titov Committed by Guolin Ke
Browse files

improved model loading routines (#1979)

parent 40486b6c
...@@ -306,21 +306,28 @@ def _dump_pandas_categorical(pandas_categorical, file_name=None): ...@@ -306,21 +306,28 @@ def _dump_pandas_categorical(pandas_categorical, file_name=None):
def _load_pandas_categorical(file_name=None, model_str=None): def _load_pandas_categorical(file_name=None, model_str=None):
pandas_key = 'pandas_categorical:'
offset = -len(pandas_key)
if file_name is not None: if file_name is not None:
with open(file_name, 'r') as f: max_offset = -os.path.getsize(file_name)
with open(file_name, 'rb') as f:
while True:
if offset < max_offset:
offset = max_offset
f.seek(offset, os.SEEK_END)
lines = f.readlines() lines = f.readlines()
last_line = lines[-1] if len(lines) >= 2:
if last_line.strip() == "": break
last_line = lines[-2] offset *= 2
if last_line.startswith('pandas_categorical:'): last_line = decode_string(lines[-1]).strip()
return json.loads(last_line[len('pandas_categorical:'):]) if not last_line.startswith(pandas_key):
last_line = decode_string(lines[-2]).strip()
elif model_str is not None: elif model_str is not None:
lines = model_str.split('\n') idx = model_str.rfind('\n', 0, offset)
last_line = lines[-1] last_line = model_str[idx:].strip()
if last_line.strip() == "": if last_line.startswith(pandas_key):
last_line = lines[-2] return json.loads(last_line[len(pandas_key):])
if last_line.startswith('pandas_categorical:'): else:
return json.loads(last_line[len('pandas_categorical:'):])
return None return None
......
...@@ -349,8 +349,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) { ...@@ -349,8 +349,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
std::unordered_map<std::string, std::string> key_vals; std::unordered_map<std::string, std::string> key_vals;
while (p < end) { while (p < end) {
auto line_len = Common::GetLine(p); auto line_len = Common::GetLine(p);
std::string cur_line(p, line_len);
if (line_len > 0) { if (line_len > 0) {
std::string cur_line(p, line_len);
if (!Common::StartsWith(cur_line, "Tree=")) { if (!Common::StartsWith(cur_line, "Tree=")) {
auto strs = Common::Split(cur_line.c_str(), '='); auto strs = Common::Split(cur_line.c_str(), '=');
if (strs.size() == 1) { if (strs.size() == 1) {
...@@ -442,8 +442,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) { ...@@ -442,8 +442,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
if (!key_vals.count("tree_sizes")) { if (!key_vals.count("tree_sizes")) {
while (p < end) { while (p < end) {
auto line_len = Common::GetLine(p); auto line_len = Common::GetLine(p);
std::string cur_line(p, line_len);
if (line_len > 0) { if (line_len > 0) {
std::string cur_line(p, line_len);
if (Common::StartsWith(cur_line, "Tree=")) { if (Common::StartsWith(cur_line, "Tree=")) {
p += line_len; p += line_len;
p = Common::SkipNewLine(p); p = Common::SkipNewLine(p);
...@@ -491,8 +491,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) { ...@@ -491,8 +491,8 @@ bool GBDT::LoadModelFromString(const char* buffer, size_t len) {
std::stringstream ss; std::stringstream ss;
while (p < end) { while (p < end) {
auto line_len = Common::GetLine(p); auto line_len = Common::GetLine(p);
std::string cur_line(p, line_len);
if (line_len > 0) { if (line_len > 0) {
std::string cur_line(p, line_len);
if (cur_line == std::string("parameters:")) { if (cur_line == std::string("parameters:")) {
is_inparameter = true; is_inparameter = true;
} else if (cur_line == std::string("end of parameters")) { } else if (cur_line == std::string("end of parameters")) {
......
...@@ -551,9 +551,11 @@ class TestEngine(unittest.TestCase): ...@@ -551,9 +551,11 @@ class TestEngine(unittest.TestCase):
"B": np.random.permutation([1, 3] * 30), "B": np.random.permutation([1, 3] * 30),
"C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15), "C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15),
"D": np.random.permutation([True, False] * 30)}) "D": np.random.permutation([True, False] * 30)})
cat_cols = []
for col in ["A", "B", "C", "D"]: for col in ["A", "B", "C", "D"]:
X[col] = X[col].astype('category') X[col] = X[col].astype('category')
X_test[col] = X_test[col].astype('category') X_test[col] = X_test[col].astype('category')
cat_cols.append(X[col].cat.categories.tolist())
params = { params = {
'objective': 'binary', 'objective': 'binary',
'metric': 'binary_logloss', 'metric': 'binary_logloss',
...@@ -588,6 +590,12 @@ class TestEngine(unittest.TestCase): ...@@ -588,6 +590,12 @@ class TestEngine(unittest.TestCase):
np.testing.assert_almost_equal(pred0, pred4) np.testing.assert_almost_equal(pred0, pred4)
np.testing.assert_almost_equal(pred0, pred5) np.testing.assert_almost_equal(pred0, pred5)
np.testing.assert_almost_equal(pred0, pred6) np.testing.assert_almost_equal(pred0, pred6)
self.assertListEqual(gbm0.pandas_categorical, cat_cols)
self.assertListEqual(gbm1.pandas_categorical, cat_cols)
self.assertListEqual(gbm2.pandas_categorical, cat_cols)
self.assertListEqual(gbm3.pandas_categorical, cat_cols)
self.assertListEqual(gbm4.pandas_categorical, cat_cols)
self.assertListEqual(gbm5.pandas_categorical, cat_cols)
def test_reference_chain(self): def test_reference_chain(self):
X = np.random.normal(size=(100, 2)) X = np.random.normal(size=(100, 2))
......
...@@ -215,25 +215,32 @@ class TestSklearn(unittest.TestCase): ...@@ -215,25 +215,32 @@ class TestSklearn(unittest.TestCase):
"B": np.random.permutation([1, 3] * 30), "B": np.random.permutation([1, 3] * 30),
"C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15), "C": np.random.permutation([0.1, -0.1, 0.2, 0.2] * 15),
"D": np.random.permutation([True, False] * 30)}) "D": np.random.permutation([True, False] * 30)})
cat_cols = []
for col in ["A", "B", "C", "D"]: for col in ["A", "B", "C", "D"]:
X[col] = X[col].astype('category') X[col] = X[col].astype('category')
X_test[col] = X_test[col].astype('category') X_test[col] = X_test[col].astype('category')
cat_cols.append(X[col].cat.categories.tolist())
gbm0 = lgb.sklearn.LGBMClassifier().fit(X, y) gbm0 = lgb.sklearn.LGBMClassifier().fit(X, y)
pred0 = list(gbm0.predict(X_test)) pred0 = gbm0.predict(X_test)
gbm1 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=[0]) gbm1 = lgb.sklearn.LGBMClassifier().fit(X, pd.Series(y), categorical_feature=[0])
pred1 = list(gbm1.predict(X_test)) pred1 = gbm1.predict(X_test)
gbm2 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['A']) gbm2 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['A'])
pred2 = list(gbm2.predict(X_test)) pred2 = gbm2.predict(X_test)
gbm3 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['A', 'B', 'C', 'D']) gbm3 = lgb.sklearn.LGBMClassifier().fit(X, y, categorical_feature=['A', 'B', 'C', 'D'])
pred3 = list(gbm3.predict(X_test)) pred3 = gbm3.predict(X_test)
gbm3.booster_.save_model('categorical.model') gbm3.booster_.save_model('categorical.model')
gbm4 = lgb.Booster(model_file='categorical.model') gbm4 = lgb.Booster(model_file='categorical.model')
pred4 = list(gbm4.predict(X_test)) pred4 = gbm4.predict(X_test)
pred_prob = list(gbm0.predict_proba(X_test)[:, 1]) pred_prob = gbm0.predict_proba(X_test)[:, 1]
np.testing.assert_almost_equal(pred0, pred1) np.testing.assert_almost_equal(pred0, pred1)
np.testing.assert_almost_equal(pred0, pred2) np.testing.assert_almost_equal(pred0, pred2)
np.testing.assert_almost_equal(pred0, pred3) np.testing.assert_almost_equal(pred0, pred3)
np.testing.assert_almost_equal(pred_prob, pred4) np.testing.assert_almost_equal(pred_prob, pred4)
self.assertListEqual(gbm0.booster_.pandas_categorical, cat_cols)
self.assertListEqual(gbm1.booster_.pandas_categorical, cat_cols)
self.assertListEqual(gbm2.booster_.pandas_categorical, cat_cols)
self.assertListEqual(gbm3.booster_.pandas_categorical, cat_cols)
self.assertListEqual(gbm4.pandas_categorical, cat_cols)
def test_predict(self): def test_predict(self):
iris = load_iris() iris = load_iris()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment