Commit cc83cd67 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

support constant tree (one-leaf tree) (#851)

parent ecc8b8cd
...@@ -167,6 +167,12 @@ public: ...@@ -167,6 +167,12 @@ public:
shrinkage_ *= rate; shrinkage_ *= rate;
} }
/*!
* \brief Turn this tree into a constant (one-leaf) tree
* \param val Value written into leaf slot 0
*/
inline void AsConstantTree(double val) {
  // leaf 0 holds the constant value
  leaf_value_[0] = val;
  // shrinkage is put back to 1
  shrinkage_ = 1.0f;
  // the leaf count is set to exactly one
  num_leaves_ = 1;
}
/*! \brief Serialize this object to string*/ /*! \brief Serialize this object to string*/
std::string ToString(); std::string ToString();
...@@ -425,7 +431,7 @@ inline double Tree::Predict(const double* feature_values) const { ...@@ -425,7 +431,7 @@ inline double Tree::Predict(const double* feature_values) const {
int leaf = GetLeaf(feature_values); int leaf = GetLeaf(feature_values);
return LeafOutput(leaf); return LeafOutput(leaf);
} else { } else {
return 0.0f; return leaf_value_[0];
} }
} }
......
...@@ -473,7 +473,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -473,7 +473,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
auto label = train_data_->metadata().label(); auto label = train_data_->metadata().label();
double init_score = ObtainAutomaticInitialScore(objective_function_, label, num_data_); double init_score = ObtainAutomaticInitialScore(objective_function_, label, num_data_);
std::unique_ptr<Tree> new_tree(new Tree(2)); std::unique_ptr<Tree> new_tree(new Tree(2));
new_tree->Split(0, 0, 0, 0, 0, init_score, init_score, 0, 0, -1, MissingType::None, true); new_tree->AsConstantTree(init_score);
train_score_updater_->AddScore(init_score, 0); train_score_updater_->AddScore(init_score, 0);
for (auto& score_updater : valid_score_updater_) { for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(init_score, 0); score_updater->AddScore(init_score, 0);
...@@ -553,8 +553,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -553,8 +553,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
// only add default score one-time // only add default score one-time
if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) { if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
auto output = class_default_output_[cur_tree_id]; auto output = class_default_output_[cur_tree_id];
new_tree->Split(0, 0, 0, 0, 0, new_tree->AsConstantTree(output);
output, output, 0, 0, -1, MissingType::None, true);
train_score_updater_->AddScore(output, cur_tree_id); train_score_updater_->AddScore(output, cur_tree_id);
for (auto& score_updater : valid_score_updater_) { for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(output, cur_tree_id); score_updater->AddScore(output, cur_tree_id);
......
...@@ -127,8 +127,7 @@ public: ...@@ -127,8 +127,7 @@ public:
if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) { if (!class_need_train_[cur_tree_id] && models_.size() < static_cast<size_t>(num_tree_per_iteration_)) {
double output = class_default_output_[cur_tree_id]; double output = class_default_output_[cur_tree_id];
objective_function_->ConvertOutput(&output, &output); objective_function_->ConvertOutput(&output, &output);
new_tree->Split(0, 0, 0, 0, 0, new_tree->AsConstantTree(output);
output, output, 0, 0, -1, MissingType::None, true);
train_score_updater_->AddScore(output, cur_tree_id); train_score_updater_->AddScore(output, cur_tree_id);
for (auto& score_updater : valid_score_updater_) { for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(output, cur_tree_id); score_updater->AddScore(output, cur_tree_id);
......
...@@ -18,7 +18,6 @@ namespace LightGBM { ...@@ -18,7 +18,6 @@ namespace LightGBM {
Tree::Tree(int max_leaves) Tree::Tree(int max_leaves)
:max_leaves_(max_leaves) { :max_leaves_(max_leaves) {
num_leaves_ = 0;
left_child_.resize(max_leaves_ - 1); left_child_.resize(max_leaves_ - 1);
right_child_.resize(max_leaves_ - 1); right_child_.resize(max_leaves_ - 1);
split_feature_inner_.resize(max_leaves_ - 1); split_feature_inner_.resize(max_leaves_ - 1);
...@@ -36,6 +35,7 @@ Tree::Tree(int max_leaves) ...@@ -36,6 +35,7 @@ Tree::Tree(int max_leaves)
// root is in the depth 0 // root is in the depth 0
leaf_depth_[0] = 0; leaf_depth_[0] = 0;
num_leaves_ = 1; num_leaves_ = 1;
leaf_value_[0] = 0.0f;
leaf_parent_[0] = -1; leaf_parent_[0] = -1;
shrinkage_ = 1.0f; shrinkage_ = 1.0f;
num_cat_ = 0; num_cat_ = 0;
...@@ -195,8 +195,6 @@ std::string Tree::ToString() { ...@@ -195,8 +195,6 @@ std::string Tree::ToString() {
<< Common::ArrayToString<int>(left_child_, num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<int>(left_child_, num_leaves_ - 1, ' ') << std::endl;
str_buf << "right_child=" str_buf << "right_child="
<< Common::ArrayToString<int>(right_child_, num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<int>(right_child_, num_leaves_ - 1, ' ') << std::endl;
str_buf << "leaf_parent="
<< Common::ArrayToString<int>(leaf_parent_, num_leaves_, ' ') << std::endl;
str_buf << "leaf_value=" str_buf << "leaf_value="
<< Common::ArrayToString<double>(leaf_value_, num_leaves_, ' ') << std::endl; << Common::ArrayToString<double>(leaf_value_, num_leaves_, ' ') << std::endl;
str_buf << "leaf_count=" str_buf << "leaf_count="
...@@ -217,7 +215,7 @@ std::string Tree::ToJSON() { ...@@ -217,7 +215,7 @@ std::string Tree::ToJSON() {
str_buf << "\"num_cat\":" << num_cat_ << "," << std::endl; str_buf << "\"num_cat\":" << num_cat_ << "," << std::endl;
str_buf << "\"shrinkage\":" << shrinkage_ << "," << std::endl; str_buf << "\"shrinkage\":" << shrinkage_ << "," << std::endl;
if (num_leaves_ == 1) { if (num_leaves_ == 1) {
str_buf << "\"tree_structure\":" << NodeToJSON(-1) << std::endl; str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << std::endl;
} else { } else {
str_buf << "\"tree_structure\":" << NodeToJSON(0) << std::endl; str_buf << "\"tree_structure\":" << NodeToJSON(0) << std::endl;
} }
...@@ -264,7 +262,6 @@ std::string Tree::NodeToJSON(int index) { ...@@ -264,7 +262,6 @@ std::string Tree::NodeToJSON(int index) {
index = ~index; index = ~index;
str_buf << "{" << std::endl; str_buf << "{" << std::endl;
str_buf << "\"leaf_index\":" << index << "," << std::endl; str_buf << "\"leaf_index\":" << index << "," << std::endl;
str_buf << "\"leaf_parent\":" << leaf_parent_[index] << "," << std::endl;
str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << std::endl; str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << std::endl;
str_buf << "\"leaf_count\":" << leaf_count_[index] << std::endl; str_buf << "\"leaf_count\":" << leaf_count_[index] << std::endl;
str_buf << "}"; str_buf << "}";
...@@ -280,8 +277,8 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) { ...@@ -280,8 +277,8 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) {
str_buf << "Leaf"; str_buf << "Leaf";
} }
str_buf << "(const double* arr) { "; str_buf << "(const double* arr) { ";
if (num_leaves_ == 1) { if (num_leaves_ <= 1) {
str_buf << "return 0"; str_buf << "return " << leaf_value_[0] << ";";
} else { } else {
// use this for the missing value conversion // use this for the missing value conversion
str_buf << "double fval = 0.0f; "; str_buf << "double fval = 0.0f; ";
...@@ -350,6 +347,12 @@ Tree::Tree(const std::string& str) { ...@@ -350,6 +347,12 @@ Tree::Tree(const std::string& str) {
Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_); Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);
if (key_vals.count("leaf_value")) {
leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], ' ', num_leaves_);
} else {
Log::Fatal("Tree model string format error, should contain leaf_value field");
}
if (num_leaves_ <= 1) { return; } if (num_leaves_ <= 1) { return; }
if (key_vals.count("left_child")) { if (key_vals.count("left_child")) {
...@@ -376,12 +379,6 @@ Tree::Tree(const std::string& str) { ...@@ -376,12 +379,6 @@ Tree::Tree(const std::string& str) {
Log::Fatal("Tree model string format error, should contain threshold field"); Log::Fatal("Tree model string format error, should contain threshold field");
} }
if (key_vals.count("leaf_value")) {
leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], ' ', num_leaves_);
} else {
Log::Fatal("Tree model string format error, should contain leaf_value field");
}
if (key_vals.count("split_gain")) { if (key_vals.count("split_gain")) {
split_gain_ = Common::StringToArray<double>(key_vals["split_gain"], ' ', num_leaves_ - 1); split_gain_ = Common::StringToArray<double>(key_vals["split_gain"], ' ', num_leaves_ - 1);
} else { } else {
...@@ -406,12 +403,6 @@ Tree::Tree(const std::string& str) { ...@@ -406,12 +403,6 @@ Tree::Tree(const std::string& str) {
leaf_count_.resize(num_leaves_); leaf_count_.resize(num_leaves_);
} }
if (key_vals.count("leaf_parent")) {
leaf_parent_ = Common::StringToArray<int>(key_vals["leaf_parent"], ' ', num_leaves_);
} else {
leaf_parent_.resize(num_leaves_);
}
if (key_vals.count("decision_type")) { if (key_vals.count("decision_type")) {
decision_type_ = Common::StringToArray<int8_t>(key_vals["decision_type"], ' ', num_leaves_ - 1); decision_type_ = Common::StringToArray<int8_t>(key_vals["decision_type"], ' ', num_leaves_ - 1);
} else { } else {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment