Commit 45cbcb05 authored by allen's avatar allen Committed by Guolin Ke
Browse files

add internal_value_ in tree and output to txt model, so we can use for...

add internal_value_ in tree and output to txt model, so we can use for calculating feature importance when predicting (#80)
parent 952099d6
......@@ -145,6 +145,8 @@ private:
int* leaf_parent_;
/*! \brief Output of leaves */
double* leaf_value_;
/*! \brief Output of internal nodes(save internal output for per inference feature importance calc) */
double* internal_value_;
/*! \brief Depth for leaves */
int* leaf_depth_;
};
......
......@@ -28,6 +28,7 @@ Tree::Tree(int max_leaves)
leaf_parent_ = new int[max_leaves_];
leaf_value_ = new double[max_leaves_];
internal_value_ = new double[max_leaves_ - 1];
leaf_depth_ = new int[max_leaves_];
// root is in the depth 1
leaf_depth_[0] = 1;
......@@ -44,6 +45,7 @@ Tree::~Tree() {
if (threshold_ != nullptr) { delete[] threshold_; }
if (split_gain_ != nullptr) { delete[] split_gain_; }
if (leaf_value_ != nullptr) { delete[] leaf_value_; }
if (internal_value_ != nullptr) { delete[] internal_value_; }
if (leaf_depth_ != nullptr) { delete[] leaf_depth_; }
}
......@@ -72,6 +74,8 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat
// update new leaves
leaf_parent_[leaf] = new_node_idx;
leaf_parent_[num_leaves_] = new_node_idx;
// save current leaf value to internal node before change
internal_value_[new_node_idx] = leaf_value_[leaf];
leaf_value_[leaf] = left_value;
leaf_value_[num_leaves_] = right_value;
// update leaf depth
......@@ -125,6 +129,8 @@ std::string Tree::ToString() {
<< Common::ArrayToString<int>(leaf_parent_, num_leaves_, ' ') << std::endl;
ss << "leaf_value="
<< Common::ArrayToString<double>(leaf_value_, num_leaves_, ' ') << std::endl;
ss << "internal_value="
<< Common::ArrayToString<double>(internal_value_, num_leaves_ - 1, ' ') << std::endl;
ss << std::endl;
return ss.str();
}
......@@ -145,7 +151,8 @@ Tree::Tree(const std::string& str) {
if (key_vals.count("num_leaves") <= 0 || key_vals.count("split_feature") <= 0
|| key_vals.count("split_gain") <= 0 || key_vals.count("threshold") <= 0
|| key_vals.count("left_child") <= 0 || key_vals.count("right_child") <= 0
|| key_vals.count("leaf_parent") <= 0 || key_vals.count("leaf_value") <= 0) {
|| key_vals.count("leaf_parent") <= 0 || key_vals.count("leaf_value") <= 0
|| key_vals.count("internal_value") <= 0) {
Log::Fatal("Tree model string format error");
}
......@@ -158,6 +165,7 @@ Tree::Tree(const std::string& str) {
split_gain_ = new double[num_leaves_ - 1];
leaf_parent_ = new int[num_leaves_];
leaf_value_ = new double[num_leaves_];
internal_value_ = new double[num_leaves_ - 1];
split_feature_ = nullptr;
threshold_in_bin_ = nullptr;
......@@ -177,6 +185,8 @@ Tree::Tree(const std::string& str) {
num_leaves_ , leaf_parent_);
Common::StringToDoubleArray(key_vals["leaf_value"], ' ',
num_leaves_ , leaf_value_);
Common::StringToDoubleArray(key_vals["internal_value"], ' ',
num_leaves_ - 1 , internal_value_);
}
} // namespace LightGBM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment