Unverified commit 1b5bec00, authored by Belinda Trotta, committed by GitHub
Browse files

Add linear leaf models to json output (fixes #4186) (#4329)



* Add linear leaf models to json output

* Add closing bracket

* Move test into test_engine.py and add asserts

* Update tests/python_package_test/test_engine.py
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update tests/python_package_test/test_engine.py
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* Update tests/python_package_test/test_engine.py
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
parent 3dd4a3f9
...@@ -243,6 +243,9 @@ class Tree { ...@@ -243,6 +243,9 @@ class Tree {
/*! \brief Serialize this object to json*/ /*! \brief Serialize this object to json*/
std::string ToJSON() const; std::string ToJSON() const;
/*! \brief Serialize linear model of tree node to json*/
std::string LinearModelToJSON(int index) const;
/*! \brief Serialize this object to if-else statement*/ /*! \brief Serialize this object to if-else statement*/
std::string ToIfElse(int index, bool predict_leaf_index) const; std::string ToIfElse(int index, bool predict_leaf_index) const;
......
...@@ -417,11 +417,39 @@ std::string Tree::ToJSON() const { ...@@ -417,11 +417,39 @@ std::string Tree::ToJSON() const {
str_buf << "\"num_cat\":" << num_cat_ << "," << '\n'; str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n'; str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
if (num_leaves_ == 1) { if (num_leaves_ == 1) {
str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n'; if (is_linear_) {
str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << ", " << "\n";
str_buf << LinearModelToJSON(0) << "}" << "\n";
} else {
str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
}
} else { } else {
str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n'; str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
} }
return str_buf.str();
}
/*!
 * \brief Serialize the linear model of one leaf as a JSON fragment.
 *
 * Emits three key/value pairs, in this order: "leaf_const", "leaf_features"
 * and "leaf_coeff". No surrounding braces are written; the callers in
 * ToJSON() and NodeToJSON() append the closing "}" themselves.
 *
 * \param index Leaf index used to look up leaf_const_, leaf_features_ and leaf_coeff_
 * \return The JSON fragment, terminated by a newline
 */
std::string Tree::LinearModelToJSON(int index) const {
  std::stringstream str_buf;
  // NOTE(review): presumably pins the stream to the classic "C" locale so numbers
  // are formatted the same everywhere -- confirm against Common::C_stringstream.
  Common::C_stringstream(str_buf);
  // Two digits beyond digits10 so that doubles are printed without losing precision.
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  str_buf << "\"leaf_const\":" << leaf_const_[index] << "," << "\n";

  const int num_features = static_cast<int>(leaf_features_[index].size());
  if (num_features == 0) {
    // Leaf without any linear terms: both arrays are written out empty.
    str_buf << "\"leaf_features\":[],\n";
    str_buf << "\"leaf_coeff\":[]\n";
    return str_buf.str();
  }

  // Separator-first loops: every element except the first is preceded by ", ".
  str_buf << "\"leaf_features\":[";
  for (int i = 0; i < num_features; ++i) {
    if (i > 0) {
      str_buf << ", ";
    }
    str_buf << leaf_features_[index][i];
  }
  str_buf << "]" << ", " << "\n";

  // leaf_coeff_[index] is read with the same count as leaf_features_[index];
  // assumes both vectors have equal length -- TODO confirm where they are filled.
  str_buf << "\"leaf_coeff\":[";
  for (int i = 0; i < num_features; ++i) {
    if (i > 0) {
      str_buf << ", ";
    }
    str_buf << leaf_coeff_[index][i];
  }
  str_buf << "]" << "\n";
  return str_buf.str();
}
...@@ -479,10 +507,14 @@ std::string Tree::NodeToJSON(int index) const { ...@@ -479,10 +507,14 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "\"leaf_index\":" << index << "," << '\n'; str_buf << "\"leaf_index\":" << index << "," << '\n';
str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n'; str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n'; str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n';
str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n'; if (is_linear_) {
str_buf << "\"leaf_count\":" << leaf_count_[index] << "," << '\n';
str_buf << LinearModelToJSON(index);
} else {
str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
}
str_buf << "}"; str_buf << "}";
} }
return str_buf.str(); return str_buf.str();
} }
......
...@@ -2793,3 +2793,28 @@ def test_reset_params_works_with_metric_num_class_and_boosting(): ...@@ -2793,3 +2793,28 @@ def test_reset_params_works_with_metric_num_class_and_boosting():
expected_params = dict(dataset_params, **booster_params) expected_params = dict(dataset_params, **booster_params)
assert bst.params == expected_params assert bst.params == expected_params
assert new_bst.params == expected_params assert new_bst.params == expected_params
def test_dump_model():
    """Check which leaf keys appear in the dumped model with and without linear trees.

    Trains a binary classifier on the breast cancer data twice: first with the
    base parameters, then with ``linear_tree`` enabled. The linear-model keys
    must be absent from the first dump and present in the second, while the
    regular leaf keys must be present in both.
    """
    linear_keys = ("leaf_features", "leaf_coeff", "leaf_const")
    regular_keys = ("leaf_value", "leaf_count")

    features, labels = load_breast_cancer(return_X_y=True)
    params = {
        "objective": "binary",
        "verbose": -1
    }

    # Without linear trees: no linear-model keys in the dump.
    plain_booster = lgb.train(params, lgb.Dataset(features, label=labels), num_boost_round=5)
    plain_dump = str(plain_booster.dump_model(5, 0))
    for key in linear_keys:
        assert key not in plain_dump
    for key in regular_keys:
        assert key in plain_dump

    # With linear trees: a fresh Dataset is built after the parameter change.
    params['linear_tree'] = True
    linear_booster = lgb.train(params, lgb.Dataset(features, label=labels), num_boost_round=5)
    linear_dump = str(linear_booster.dump_model(5, 0))
    for key in linear_keys + regular_keys:
        assert key in linear_dump
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment