".github/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "1d0d746e7f95ca654ed410f9da1f9161ce0fc7c1"
Commit 2c572a71 authored by Guolin Ke's avatar Guolin Ke
Browse files

clean code for Tree.

parent 6d0eae0c
...@@ -60,7 +60,7 @@ public: ...@@ -60,7 +60,7 @@ public:
* \param real_feature Index of feature, the original index on data * \param real_feature Index of feature, the original index on data
* \param threshold_bin Threshold(bin) of split, use bitset to represent * \param threshold_bin Threshold(bin) of split, use bitset to represent
* \param num_threshold_bin size of threshold_bin * \param num_threshold_bin size of threshold_bin
* \param threshold * \param threshold
* \param left_value Model Left child output * \param left_value Model Left child output
* \param right_value Model Right child output * \param right_value Model Right child output
* \param left_cnt Count of left child * \param left_cnt Count of left child
...@@ -112,32 +112,6 @@ public: ...@@ -112,32 +112,6 @@ public:
inline void PredictContrib(const double* feature_values, int num_features, double* output) const; inline void PredictContrib(const double* feature_values, int num_features, double* output) const;
inline double ExpectedValue(int node = 0) const;
inline int MaxDepth() const;
/*!
* \brief Used by TreeSHAP for data we keep about our decision path
*
* A default-constructed element is fully initialized (no feature assigned,
* all fractions and weights zero) so that freshly allocated path buffers
* never hold indeterminate values.
*/
struct PathElement {
  /*! \brief Feature index stored for this path entry; -1 means no feature assigned */
  int feature_index = -1;
  /*! \brief zero_fraction value recorded for this path entry */
  double zero_fraction = 0.0;
  /*! \brief one_fraction value recorded for this path entry */
  double one_fraction = 0.0;
  // note that pweight is included for convenience and is not tied with the other attributes,
  // the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them
  double pweight = 0.0;
  PathElement() = default;
  PathElement(int i, double z, double o, double w) : feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {}
};
/*! \brief Polynomial time algorithm for SHAP values (https://arxiv.org/abs/1706.06060) */
inline void TreeSHAP(const double *feature_values, double *phi,
int node, int unique_depth,
PathElement *parent_unique_path, double parent_zero_fraction,
double parent_one_fraction, int parent_feature_index) const;
/*! \brief Get Number of leaves*/ /*! \brief Get Number of leaves*/
inline int num_leaves() const { return num_leaves_; } inline int num_leaves() const { return num_leaves_; }
...@@ -150,7 +124,7 @@ public: ...@@ -150,7 +124,7 @@ public:
inline double split_gain(int split_idx) const { return split_gain_[split_idx]; } inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }
/*! \brief Get the number of data points that fall at or below this node*/ /*! \brief Get the number of data points that fall at or below this node*/
inline int data_count(int node = 0) const { return node >= 0 ? internal_count_[node] : leaf_count_[~node]; } inline int data_count(int node) const { return node >= 0 ? internal_count_[node] : leaf_count_[~node]; }
/*! /*!
* \brief Shrinkage for the tree's output * \brief Shrinkage for the tree's output
...@@ -161,8 +135,7 @@ public: ...@@ -161,8 +135,7 @@ public:
#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048) #pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
for (int i = 0; i < num_leaves_; ++i) { for (int i = 0; i < num_leaves_; ++i) {
leaf_value_[i] *= rate; leaf_value_[i] *= rate;
if (leaf_value_[i] > kMaxTreeOutput) { leaf_value_[i] = kMaxTreeOutput; } if (leaf_value_[i] > kMaxTreeOutput) { leaf_value_[i] = kMaxTreeOutput; } else if (leaf_value_[i] < -kMaxTreeOutput) { leaf_value_[i] = -kMaxTreeOutput; }
else if (leaf_value_[i] < -kMaxTreeOutput) { leaf_value_[i] = -kMaxTreeOutput; }
} }
shrinkage_ *= rate; shrinkage_ *= rate;
} }
...@@ -183,13 +156,13 @@ public: ...@@ -183,13 +156,13 @@ public:
} }
/*! \brief Serialize this object to string*/ /*! \brief Serialize this object to string*/
std::string ToString(); std::string ToString() const;
/*! \brief Serialize this object to json*/ /*! \brief Serialize this object to json*/
std::string ToJSON(); std::string ToJSON() const;
/*! \brief Serialize this object to if-else statement*/ /*! \brief Serialize this object to if-else statement*/
std::string ToIfElse(int index, bool is_predict_leaf_index); std::string ToIfElse(int index, bool is_predict_leaf_index) const;
inline static bool IsZero(double fval) { inline static bool IsZero(double fval) {
if (fval > -kZeroAsMissingValueRange && fval <= kZeroAsMissingValueRange) { if (fval > -kZeroAsMissingValueRange && fval <= kZeroAsMissingValueRange) {
...@@ -222,39 +195,9 @@ public: ...@@ -222,39 +195,9 @@ public:
private: private:
inline std::string NumericalDecisionIfElse(int node) { std::string NumericalDecisionIfElse(int node) const;
std::stringstream str_buf;
uint8_t missing_type = GetMissingType(decision_type_[node]);
bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
if (missing_type == 0 || (missing_type == 1 && default_left && kZeroAsMissingValueRange < threshold_[node])) {
str_buf << "if (fval <= " << threshold_[node] << ") {";
} else if (missing_type == 1) {
if (default_left) {
str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
} else {
str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
}
} else {
if (default_left) {
str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
} else {
str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
}
}
return str_buf.str();
}
inline std::string CategoricalDecisionIfElse(int node) const { std::string CategoricalDecisionIfElse(int node) const;
uint8_t missing_type = GetMissingType(decision_type_[node]);
std::stringstream str_buf;
if (missing_type == 2) {
str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
} else {
str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
}
str_buf << "if (int_fval >= 0 && int_fval == " << static_cast<int>(threshold_[node]) << ") {";
return str_buf.str();
}
inline int NumericalDecision(double fval, int node) const { inline int NumericalDecision(double fval, int node) const {
uint8_t missing_type = GetMissingType(decision_type_[node]); uint8_t missing_type = GetMissingType(decision_type_[node]);
...@@ -346,20 +289,46 @@ private: ...@@ -346,20 +289,46 @@ private:
inline int GetLeaf(const double* feature_values) const; inline int GetLeaf(const double* feature_values) const;
/*! \brief Serialize one node to json*/ /*! \brief Serialize one node to json*/
inline std::string NodeToJSON(int index); std::string NodeToJSON(int index) const;
/*! \brief Serialize one node to if-else statement*/ /*! \brief Serialize one node to if-else statement*/
inline std::string NodeToIfElse(int index, bool is_predict_leaf_index); std::string NodeToIfElse(int index, bool is_predict_leaf_index) const;
double ExpectedValue(int node = 0) const;
int MaxDepth() const;
/*!
* \brief Used by TreeSHAP for data we keep about our decision path
*
* A default-constructed element is fully initialized (no feature assigned,
* all fractions and weights zero) so that freshly allocated path buffers
* never hold indeterminate values.
*/
struct PathElement {
  /*! \brief Feature index stored for this path entry; -1 means no feature assigned */
  int feature_index = -1;
  /*! \brief zero_fraction value recorded for this path entry */
  double zero_fraction = 0.0;
  /*! \brief one_fraction value recorded for this path entry */
  double one_fraction = 0.0;
  // note that pweight is included for convenience and is not tied with the other attributes,
  // the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them
  double pweight = 0.0;
  PathElement() = default;
  PathElement(int i, double z, double o, double w) : feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {}
};
/*! \brief Polynomial time algorithm for SHAP values (https://arxiv.org/abs/1706.06060) */
void TreeSHAP(const double *feature_values, double *phi,
int node, int unique_depth,
PathElement *parent_unique_path, double parent_zero_fraction,
double parent_one_fraction, int parent_feature_index) const;
/*! \brief Extend our decision path with a fraction of one and zero extensions for TreeSHAP*/ /*! \brief Extend our decision path with a fraction of one and zero extensions for TreeSHAP*/
inline static void ExtendPath(PathElement *unique_path, int unique_depth, static void ExtendPath(PathElement *unique_path, int unique_depth,
double zero_fraction, double one_fraction, int feature_index); double zero_fraction, double one_fraction, int feature_index);
/*! \brief Undo a previous extension of the decision path for TreeSHAP*/ /*! \brief Undo a previous extension of the decision path for TreeSHAP*/
inline static void UnwindPath(PathElement *unique_path, int unique_depth, int path_index); static void UnwindPath(PathElement *unique_path, int unique_depth, int path_index);
/*! determine what the total permutation weight would be if we unwound a previous extension in the decision path*/ /*! determine what the total permutation weight would be if we unwound a previous extension in the decision path*/
inline static double UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index); static double UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index);
/*! \brief Number of max leaves*/ /*! \brief Number of max leaves*/
int max_leaves_; int max_leaves_;
...@@ -453,143 +422,12 @@ inline int Tree::PredictLeafIndex(const double* feature_values) const { ...@@ -453,143 +422,12 @@ inline int Tree::PredictLeafIndex(const double* feature_values) const {
} }
} }
inline void Tree::ExtendPath(PathElement *unique_path, int unique_depth, inline void Tree::PredictContrib(const double* feature_values, int num_features, double* output) const {
double zero_fraction, double one_fraction, int feature_index) {
unique_path[unique_depth].feature_index = feature_index;
unique_path[unique_depth].zero_fraction = zero_fraction;
unique_path[unique_depth].one_fraction = one_fraction;
unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
for (int i = unique_depth-1; i >= 0; i--) {
unique_path[i+1].pweight += one_fraction*unique_path[i].pweight*(i+1)
/ static_cast<double>(unique_depth+1);
unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth-i)
/ static_cast<double>(unique_depth+1);
}
}
// Undo the extension stored at path_index: first rewrite the pweight of
// entries [0, unique_depth) in place, then shift the feature index and the
// two fractions of every later entry one slot towards the front.
// The pweight values themselves are not shifted.
inline void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  // Copy the removed entry's fractions up front; the shift below overwrites them.
  const double removed_one = unique_path[path_index].one_fraction;
  const double removed_zero = unique_path[path_index].zero_fraction;
  double carry = unique_path[unique_depth].pweight;

  for (int pos = unique_depth - 1; pos >= 0; --pos) {
    double &weight = unique_path[pos].pweight;
    if (removed_one == 0) {
      // Only the zero_fraction scaling was applied to this entry; divide it back out.
      weight = (weight*(unique_depth + 1))
               / static_cast<double>(removed_zero*(unique_depth - pos));
      continue;
    }
    const double before = weight;
    weight = carry*(unique_depth + 1)
             / static_cast<double>((pos + 1)*removed_one);
    carry = before - weight*removed_zero*(unique_depth - pos)
                     / static_cast<double>(unique_depth + 1);
  }

  for (int pos = path_index; pos < unique_depth; ++pos) {
    const PathElement &following = unique_path[pos + 1];
    unique_path[pos].feature_index = following.feature_index;
    unique_path[pos].zero_fraction = following.zero_fraction;
    unique_path[pos].one_fraction = following.one_fraction;
  }
}
inline double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
const double one_fraction = unique_path[path_index].one_fraction;
const double zero_fraction = unique_path[path_index].zero_fraction;
double next_one_portion = unique_path[unique_depth].pweight;
double total = 0;
for (int i = unique_depth-1; i >= 0; --i) {
if (one_fraction != 0) {
const double tmp = next_one_portion*(unique_depth+1)
/ static_cast<double>((i+1)*one_fraction);
total += tmp;
next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth-i)
/ static_cast<double>(unique_depth+1));
} else {
total += (unique_path[i].pweight/zero_fraction)/((unique_depth-i)
/ static_cast<double>(unique_depth+1));
}
}
return total;
}
// Recursive computation of SHAP values for a decision tree.
//
// feature_values       values of the row being explained, indexed by split feature
// phi                  output array; contributions are accumulated (+=) per feature index
// node                 node to visit: >= 0 is an internal node, < 0 is a leaf encoded as ~leaf_index
// unique_depth         number of entries already on the parent's unique path
// parent_unique_path   parent's path buffer; this call's working copy is placed
//                      unique_depth elements after it, so the buffer must be large enough
// parent_zero_fraction, parent_one_fraction, parent_feature_index
//                      the entry appended to the path for the edge leading to this node
inline void Tree::TreeSHAP(const double *feature_values, double *phi,
                           int node, int unique_depth,
                           PathElement *parent_unique_path, double parent_zero_fraction,
                           double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
  PathElement *unique_path = parent_unique_path + unique_depth;
  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path+unique_depth, unique_path);
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction-el.zero_fraction)*leaf_value_[~node];
    }

  // internal node
  } else {
    // Read the split feature only for internal nodes: for a leaf, node is
    // negative and split_feature_[node] would be an out-of-bounds access.
    const int split_index = split_feature_[node];
    const int hot_index = Decision(feature_values[split_index], node);
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index)/w;
    const double cold_zero_fraction = data_count(cold_index)/w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
      if (unique_path[path_index].feature_index == split_index) break;
    }
    if (path_index != unique_depth+1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    // the branch the row follows keeps its one_fraction; the other branch gets 0
    TreeSHAP(feature_values, phi, hot_index, unique_depth+1, unique_path,
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_index);

    TreeSHAP(feature_values, phi, cold_index, unique_depth+1, unique_path,
             cold_zero_fraction*incoming_zero_fraction, 0, split_index);
  }
}
inline void Tree::PredictContrib(const double* feature_values, int num_features, double *output) const {
output[num_features] += ExpectedValue(); output[num_features] += ExpectedValue();
// Run the recursion with preallocated space for the unique path data // Run the recursion with preallocated space for the unique path data
const int max_path_len = MaxDepth()+1; const int max_path_len = MaxDepth() + 1;
PathElement *unique_path_data = new PathElement[(max_path_len*(max_path_len+1))/2]; std::vector<PathElement> unique_path_data((max_path_len*(max_path_len + 1)) / 2);
TreeSHAP(feature_values, output, 0, 0, unique_path_data, 1, 1, -1); TreeSHAP(feature_values, output, 0, 0, unique_path_data.data(), 1, 1, -1);
delete[] unique_path_data;
}
// Expected output of the subtree rooted at node: a leaf (negative, encoded as
// ~leaf_index) yields its leaf output; an internal node yields the average of
// its children's expected values weighted by their data counts.
inline double Tree::ExpectedValue(int node) const {
  if (node < 0) {
    return LeafOutput(~node);
  }
  const int left = left_child_[node];
  const int right = right_child_[node];
  const double weighted_sum = data_count(left)*ExpectedValue(left) + data_count(right)*ExpectedValue(right);
  return weighted_sum/data_count(node);
}
inline int Tree::MaxDepth() const {
int max_depth = 0;
for (int i = 0; i < num_leaves(); ++i) {
if (max_depth < leaf_depth_[i]) max_depth = leaf_depth_[i];
}
return max_depth;
} }
inline int Tree::GetLeaf(const double* feature_values) const { inline int Tree::GetLeaf(const double* feature_values) const {
......
...@@ -147,16 +147,16 @@ void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, doubl ...@@ -147,16 +147,16 @@ void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, doubl
} }
void Tree::AddPredictionToScore(const Dataset* data, void Tree::AddPredictionToScore(const Dataset* data,
const data_size_t* used_data_indices, const data_size_t* used_data_indices,
data_size_t num_data, double* score) const { data_size_t num_data, double* score) const {
if (num_leaves_ <= 1) { if (num_leaves_ <= 1) {
if (leaf_value_[0] != 0.0f) { if (leaf_value_[0] != 0.0f) {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) { for (data_size_t i = 0; i < num_data; ++i) {
score[used_data_indices[i]] += leaf_value_[0]; score[used_data_indices[i]] += leaf_value_[0];
} }
} }
return; return;
} }
std::vector<uint32_t> default_bins(num_leaves_ - 1); std::vector<uint32_t> default_bins(num_leaves_ - 1);
std::vector<uint32_t> max_bins(num_leaves_ - 1); std::vector<uint32_t> max_bins(num_leaves_ - 1);
...@@ -195,7 +195,7 @@ void Tree::AddPredictionToScore(const Dataset* data, ...@@ -195,7 +195,7 @@ void Tree::AddPredictionToScore(const Dataset* data,
#undef PredictionFun #undef PredictionFun
std::string Tree::ToString() { std::string Tree::ToString() const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << "num_leaves=" << num_leaves_ << std::endl; str_buf << "num_leaves=" << num_leaves_ << std::endl;
str_buf << "num_cat=" << num_cat_ << std::endl; str_buf << "num_cat=" << num_cat_ << std::endl;
...@@ -224,14 +224,14 @@ std::string Tree::ToString() { ...@@ -224,14 +224,14 @@ std::string Tree::ToString() {
return str_buf.str(); return str_buf.str();
} }
std::string Tree::ToJSON() { std::string Tree::ToJSON() const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2); str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
str_buf << "\"num_leaves\":" << num_leaves_ << "," << std::endl; str_buf << "\"num_leaves\":" << num_leaves_ << "," << std::endl;
str_buf << "\"num_cat\":" << num_cat_ << "," << std::endl; str_buf << "\"num_cat\":" << num_cat_ << "," << std::endl;
str_buf << "\"shrinkage\":" << shrinkage_ << "," << std::endl; str_buf << "\"shrinkage\":" << shrinkage_ << "," << std::endl;
if (num_leaves_ == 1) { if (num_leaves_ == 1) {
str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << std::endl; str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << std::endl;
} else { } else {
str_buf << "\"tree_structure\":" << NodeToJSON(0) << std::endl; str_buf << "\"tree_structure\":" << NodeToJSON(0) << std::endl;
} }
...@@ -239,7 +239,7 @@ std::string Tree::ToJSON() { ...@@ -239,7 +239,7 @@ std::string Tree::ToJSON() {
return str_buf.str(); return str_buf.str();
} }
std::string Tree::NodeToJSON(int index) { std::string Tree::NodeToJSON(int index) const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2); str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
if (index >= 0) { if (index >= 0) {
...@@ -286,7 +286,41 @@ std::string Tree::NodeToJSON(int index) { ...@@ -286,7 +286,41 @@ std::string Tree::NodeToJSON(int index) {
return str_buf.str(); return str_buf.str();
} }
std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) { std::string Tree::NumericalDecisionIfElse(int node) const {
std::stringstream str_buf;
uint8_t missing_type = GetMissingType(decision_type_[node]);
bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
if (missing_type == 0 || (missing_type == 1 && default_left && kZeroAsMissingValueRange < threshold_[node])) {
str_buf << "if (fval <= " << threshold_[node] << ") {";
} else if (missing_type == 1) {
if (default_left) {
str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
} else {
str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
}
} else {
if (default_left) {
str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
} else {
str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
}
}
return str_buf.str();
}
std::string Tree::CategoricalDecisionIfElse(int node) const {
uint8_t missing_type = GetMissingType(decision_type_[node]);
std::stringstream str_buf;
if (missing_type == 2) {
str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
} else {
str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
}
str_buf << "if (int_fval >= 0 && int_fval == " << static_cast<int>(threshold_[node]) << ") {";
return str_buf.str();
}
std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << "double PredictTree" << index; str_buf << "double PredictTree" << index;
if (is_predict_leaf_index) { if (is_predict_leaf_index) {
...@@ -307,7 +341,7 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) { ...@@ -307,7 +341,7 @@ std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) {
return str_buf.str(); return str_buf.str();
} }
std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) { std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) const {
std::stringstream str_buf; std::stringstream str_buf;
str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2); str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
if (index >= 0) { if (index >= 0) {
...@@ -432,4 +466,134 @@ Tree::Tree(const std::string& str) { ...@@ -432,4 +466,134 @@ Tree::Tree(const std::string& str) {
} }
} }
// Append a new entry at index unique_depth of the path and fold its
// zero/one fractions into the pweights of the entries before it.
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  PathElement &appended = unique_path[unique_depth];
  appended.feature_index = feature_index;
  appended.zero_fraction = zero_fraction;
  appended.one_fraction = one_fraction;
  // only the very first entry of a path starts with weight 1
  appended.pweight = (unique_depth == 0) ? 1 : 0;

  const double new_length = static_cast<double>(unique_depth + 1);
  // walk from the entry just before the new one back to the front
  for (int pos = unique_depth; pos-- > 0;) {
    const double current = unique_path[pos].pweight;
    unique_path[pos + 1].pweight += one_fraction*current*(pos + 1) / new_length;
    unique_path[pos].pweight = zero_fraction*current*(unique_depth - pos) / new_length;
  }
}
// Undo the extension stored at path_index: first rewrite the pweight of
// entries [0, unique_depth) in place, then shift the feature index and the
// two fractions of every later entry one slot towards the front.
// The pweight values themselves are not shifted.
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  // Copy the removed entry's fractions up front; the shift below overwrites them.
  const double removed_one = unique_path[path_index].one_fraction;
  const double removed_zero = unique_path[path_index].zero_fraction;
  double carry = unique_path[unique_depth].pweight;

  for (int pos = unique_depth - 1; pos >= 0; --pos) {
    double &weight = unique_path[pos].pweight;
    if (removed_one == 0) {
      // Only the zero_fraction scaling was applied to this entry; divide it back out.
      weight = (weight*(unique_depth + 1))
               / static_cast<double>(removed_zero*(unique_depth - pos));
      continue;
    }
    const double before = weight;
    weight = carry*(unique_depth + 1)
             / static_cast<double>((pos + 1)*removed_one);
    carry = before - weight*removed_zero*(unique_depth - pos)
                     / static_cast<double>(unique_depth + 1);
  }

  for (int pos = path_index; pos < unique_depth; ++pos) {
    const PathElement &following = unique_path[pos + 1];
    unique_path[pos].feature_index = following.feature_index;
    unique_path[pos].zero_fraction = following.zero_fraction;
    unique_path[pos].one_fraction = following.one_fraction;
  }
}
double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
const double one_fraction = unique_path[path_index].one_fraction;
const double zero_fraction = unique_path[path_index].zero_fraction;
double next_one_portion = unique_path[unique_depth].pweight;
double total = 0;
for (int i = unique_depth - 1; i >= 0; --i) {
if (one_fraction != 0) {
const double tmp = next_one_portion*(unique_depth + 1)
/ static_cast<double>((i + 1)*one_fraction);
total += tmp;
next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth - i)
/ static_cast<double>(unique_depth + 1));
} else {
total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
/ static_cast<double>(unique_depth + 1));
}
}
return total;
}
// Recursive computation of SHAP values for a decision tree.
//
// feature_values       values of the row being explained, indexed by split feature
// phi                  output array; contributions are accumulated (+=) per feature index
// node                 node to visit: >= 0 is an internal node, < 0 is a leaf encoded as ~leaf_index
// unique_depth         number of entries already on the parent's unique path
// parent_unique_path   parent's path buffer; this call's working copy is placed
//                      unique_depth elements after it, so the buffer must be large enough
// parent_zero_fraction, parent_one_fraction, parent_feature_index
//                      the entry appended to the path for the edge leading to this node
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
  PathElement *unique_path = parent_unique_path + unique_depth;
  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

  // internal node
  } else {
    // Read the split feature only for internal nodes: for a leaf, node is
    // negative and split_feature_[node] would be an out-of-bounds access.
    const int split_index = split_feature_[node];
    const int hot_index = Decision(feature_values[split_index], node);
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
      if (unique_path[path_index].feature_index == split_index) break;
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    // the branch the row follows keeps its one_fraction; the other branch gets 0
    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_index);

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
             cold_zero_fraction*incoming_zero_fraction, 0, split_index);
  }
}
// Expected output of the subtree rooted at node: a leaf (negative, encoded as
// ~leaf_index) yields its leaf output; an internal node yields the average of
// its children's expected values weighted by their data counts.
double Tree::ExpectedValue(int node) const {
  if (node < 0) {
    return LeafOutput(~node);
  }
  const int left = left_child_[node];
  const int right = right_child_[node];
  const double weighted_sum = data_count(left)*ExpectedValue(left) + data_count(right)*ExpectedValue(right);
  return weighted_sum / data_count(node);
}
int Tree::MaxDepth() const {
int max_depth = 0;
for (int i = 0; i < num_leaves(); ++i) {
if (max_depth < leaf_depth_[i]) max_depth = leaf_depth_[i];
}
return max_depth;
}
} // namespace LightGBM } // namespace LightGBM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment