Commit 45cbcb05 authored by allen's avatar allen Committed by Guolin Ke
Browse files

add internal_value_ in tree and output to txt model, so we can use for...

add internal_value_ in tree and output to txt model, so we can use for calculating feature importance when predicting (#80)
parent 952099d6
...@@ -145,6 +145,8 @@ private: ...@@ -145,6 +145,8 @@ private:
int* leaf_parent_; int* leaf_parent_;
/*! \brief Output of leaves */ /*! \brief Output of leaves */
double* leaf_value_; double* leaf_value_;
/*! \brief Output of internal nodes(save internal output for per inference feature importance calc) */
double* internal_value_;
/*! \brief Depth for leaves */ /*! \brief Depth for leaves */
int* leaf_depth_; int* leaf_depth_;
}; };
......
...@@ -28,6 +28,7 @@ Tree::Tree(int max_leaves) ...@@ -28,6 +28,7 @@ Tree::Tree(int max_leaves)
leaf_parent_ = new int[max_leaves_]; leaf_parent_ = new int[max_leaves_];
leaf_value_ = new double[max_leaves_]; leaf_value_ = new double[max_leaves_];
internal_value_ = new double[max_leaves_ - 1];
leaf_depth_ = new int[max_leaves_]; leaf_depth_ = new int[max_leaves_];
// root is in the depth 1 // root is in the depth 1
leaf_depth_[0] = 1; leaf_depth_[0] = 1;
...@@ -44,6 +45,7 @@ Tree::~Tree() { ...@@ -44,6 +45,7 @@ Tree::~Tree() {
if (threshold_ != nullptr) { delete[] threshold_; } if (threshold_ != nullptr) { delete[] threshold_; }
if (split_gain_ != nullptr) { delete[] split_gain_; } if (split_gain_ != nullptr) { delete[] split_gain_; }
if (leaf_value_ != nullptr) { delete[] leaf_value_; } if (leaf_value_ != nullptr) { delete[] leaf_value_; }
if (internal_value_ != nullptr) { delete[] internal_value_; }
if (leaf_depth_ != nullptr) { delete[] leaf_depth_; } if (leaf_depth_ != nullptr) { delete[] leaf_depth_; }
} }
...@@ -72,6 +74,8 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat ...@@ -72,6 +74,8 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat
// update new leaves // update new leaves
leaf_parent_[leaf] = new_node_idx; leaf_parent_[leaf] = new_node_idx;
leaf_parent_[num_leaves_] = new_node_idx; leaf_parent_[num_leaves_] = new_node_idx;
// save current leaf value to internal node before change
internal_value_[new_node_idx] = leaf_value_[leaf];
leaf_value_[leaf] = left_value; leaf_value_[leaf] = left_value;
leaf_value_[num_leaves_] = right_value; leaf_value_[num_leaves_] = right_value;
// update leaf depth // update leaf depth
...@@ -125,6 +129,8 @@ std::string Tree::ToString() { ...@@ -125,6 +129,8 @@ std::string Tree::ToString() {
<< Common::ArrayToString<int>(leaf_parent_, num_leaves_, ' ') << std::endl; << Common::ArrayToString<int>(leaf_parent_, num_leaves_, ' ') << std::endl;
ss << "leaf_value=" ss << "leaf_value="
<< Common::ArrayToString<double>(leaf_value_, num_leaves_, ' ') << std::endl; << Common::ArrayToString<double>(leaf_value_, num_leaves_, ' ') << std::endl;
ss << "internal_value="
<< Common::ArrayToString<double>(internal_value_, num_leaves_ - 1, ' ') << std::endl;
ss << std::endl; ss << std::endl;
return ss.str(); return ss.str();
} }
...@@ -145,7 +151,8 @@ Tree::Tree(const std::string& str) { ...@@ -145,7 +151,8 @@ Tree::Tree(const std::string& str) {
if (key_vals.count("num_leaves") <= 0 || key_vals.count("split_feature") <= 0 if (key_vals.count("num_leaves") <= 0 || key_vals.count("split_feature") <= 0
|| key_vals.count("split_gain") <= 0 || key_vals.count("threshold") <= 0 || key_vals.count("split_gain") <= 0 || key_vals.count("threshold") <= 0
|| key_vals.count("left_child") <= 0 || key_vals.count("right_child") <= 0 || key_vals.count("left_child") <= 0 || key_vals.count("right_child") <= 0
|| key_vals.count("leaf_parent") <= 0 || key_vals.count("leaf_value") <= 0) { || key_vals.count("leaf_parent") <= 0 || key_vals.count("leaf_value") <= 0
|| key_vals.count("internal_value") <= 0) {
Log::Fatal("Tree model string format error"); Log::Fatal("Tree model string format error");
} }
...@@ -158,6 +165,7 @@ Tree::Tree(const std::string& str) { ...@@ -158,6 +165,7 @@ Tree::Tree(const std::string& str) {
split_gain_ = new double[num_leaves_ - 1]; split_gain_ = new double[num_leaves_ - 1];
leaf_parent_ = new int[num_leaves_]; leaf_parent_ = new int[num_leaves_];
leaf_value_ = new double[num_leaves_]; leaf_value_ = new double[num_leaves_];
internal_value_ = new double[num_leaves_ - 1];
split_feature_ = nullptr; split_feature_ = nullptr;
threshold_in_bin_ = nullptr; threshold_in_bin_ = nullptr;
...@@ -177,6 +185,8 @@ Tree::Tree(const std::string& str) { ...@@ -177,6 +185,8 @@ Tree::Tree(const std::string& str) {
num_leaves_ , leaf_parent_); num_leaves_ , leaf_parent_);
Common::StringToDoubleArray(key_vals["leaf_value"], ' ', Common::StringToDoubleArray(key_vals["leaf_value"], ' ',
num_leaves_ , leaf_value_); num_leaves_ , leaf_value_);
Common::StringToDoubleArray(key_vals["internal_value"], ' ',
num_leaves_ - 1 , internal_value_);
} }
} // namespace LightGBM } // namespace LightGBM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment