Commit 61fb5ea2 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix a multi-thread bug in pred_contrib

parent 93237586
...@@ -257,8 +257,9 @@ public: ...@@ -257,8 +257,9 @@ public:
/*! /*!
* \brief Initial work for the prediction * \brief Initial work for the prediction
* \param num_iteration number of used iteration * \param num_iteration number of used iteration
* \param is_pred_contrib
*/ */
virtual void InitPredict(int num_iteration) = 0; virtual void InitPredict(int num_iteration, bool is_pred_contrib) = 0;
/*! /*!
* \brief Name of submodel * \brief Name of submodel
......
...@@ -203,6 +203,8 @@ public: ...@@ -203,6 +203,8 @@ public:
(*decision_type) |= (input << 2); (*decision_type) |= (input << 2);
} }
void RecomputeMaxDepth();
private: private:
std::string NumericalDecisionIfElse(int node) const; std::string NumericalDecisionIfElse(int node) const;
...@@ -313,8 +315,6 @@ private: ...@@ -313,8 +315,6 @@ private:
double ExpectedValue() const; double ExpectedValue() const;
int MaxDepth();
/*! \brief This is used fill in leaf_depth_ after reloading a model*/ /*! \brief This is used fill in leaf_depth_ after reloading a model*/
inline void RecomputeLeafDepths(int node = 0, int depth = 0); inline void RecomputeLeafDepths(int node = 0, int depth = 0);
...@@ -390,6 +390,7 @@ private: ...@@ -390,6 +390,7 @@ private:
/*! \brief Depth for leaves */ /*! \brief Depth for leaves */
std::vector<int> leaf_depth_; std::vector<int> leaf_depth_;
double shrinkage_; double shrinkage_;
int max_depth_;
}; };
inline void Tree::Split(int leaf, int feature, int real_feature, inline void Tree::Split(int leaf, int feature, int real_feature,
...@@ -468,10 +469,10 @@ inline void Tree::PredictContrib(const double* feature_values, int num_features, ...@@ -468,10 +469,10 @@ inline void Tree::PredictContrib(const double* feature_values, int num_features,
output[num_features] += ExpectedValue(); output[num_features] += ExpectedValue();
// Run the recursion with preallocated space for the unique path data // Run the recursion with preallocated space for the unique path data
if (num_leaves_ > 1) { if (num_leaves_ > 1) {
const int max_path_len = MaxDepth() + 1; CHECK(max_depth_ >= 0);
PathElement *unique_path_data = new PathElement[(max_path_len*(max_path_len + 1)) / 2]; const int max_path_len = max_depth_ + 1;
TreeSHAP(feature_values, output, 0, 0, unique_path_data, 1, 1, -1); std::vector<PathElement> unique_path_data(max_path_len*(max_path_len + 1) / 2);
delete[] unique_path_data; TreeSHAP(feature_values, output, 0, 0, unique_path_data.data(), 1, 1, -1);
} }
} }
......
...@@ -55,7 +55,7 @@ public: ...@@ -55,7 +55,7 @@ public:
{ {
num_threads_ = omp_get_num_threads(); num_threads_ = omp_get_num_threads();
} }
boosting->InitPredict(num_iteration); boosting->InitPredict(num_iteration, is_predict_contrib);
boosting_ = boosting; boosting_ = boosting;
num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index, is_predict_contrib); num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index, is_predict_contrib);
num_feature_ = boosting_->MaxFeatureIdx() + 1; num_feature_ = boosting_->MaxFeatureIdx() + 1;
......
...@@ -291,11 +291,17 @@ public: ...@@ -291,11 +291,17 @@ public:
*/ */
inline int NumberOfClasses() const override { return num_class_; } inline int NumberOfClasses() const override { return num_class_; }
inline void InitPredict(int num_iteration) override { inline void InitPredict(int num_iteration, bool is_pred_contrib) override {
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_; num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_tree_per_iteration_;
if (num_iteration > 0) { if (num_iteration > 0) {
num_iteration_for_pred_ = std::min(num_iteration, num_iteration_for_pred_); num_iteration_for_pred_ = std::min(num_iteration, num_iteration_for_pred_);
} }
if (is_pred_contrib) {
#pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(models_.size()); ++i) {
models_[i]->RecomputeMaxDepth();
}
}
} }
inline double GetLeafValue(int tree_idx, int leaf_idx) const override { inline double GetLeafValue(int tree_idx, int leaf_idx) const override {
......
...@@ -41,6 +41,7 @@ Tree::Tree(int max_leaves) ...@@ -41,6 +41,7 @@ Tree::Tree(int max_leaves)
num_cat_ = 0; num_cat_ = 0;
cat_boundaries_.push_back(0); cat_boundaries_.push_back(0);
cat_boundaries_inner_.push_back(0); cat_boundaries_inner_.push_back(0);
max_depth_ = -1;
} }
Tree::~Tree() { Tree::~Tree() {
...@@ -584,6 +585,7 @@ Tree::Tree(const char* str, size_t* used_len) { ...@@ -584,6 +585,7 @@ Tree::Tree(const char* str, size_t* used_len) {
} else { } else {
shrinkage_ = 1.0f; shrinkage_ = 1.0f;
} }
max_depth_ = -1;
} }
void Tree::ExtendPath(PathElement *unique_path, int unique_depth, void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
...@@ -652,7 +654,7 @@ void Tree::TreeSHAP(const double *feature_values, double *phi, ...@@ -652,7 +654,7 @@ void Tree::TreeSHAP(const double *feature_values, double *phi,
double parent_one_fraction, int parent_feature_index) const { double parent_one_fraction, int parent_feature_index) const {
// extend the unique path // extend the unique path
PathElement *unique_path = parent_unique_path + unique_depth; PathElement* unique_path = parent_unique_path + unique_depth;
if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path); if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
ExtendPath(unique_path, unique_depth, parent_zero_fraction, ExtendPath(unique_path, unique_depth, parent_zero_fraction,
parent_one_fraction, parent_feature_index); parent_one_fraction, parent_feature_index);
...@@ -706,14 +708,18 @@ double Tree::ExpectedValue() const { ...@@ -706,14 +708,18 @@ double Tree::ExpectedValue() const {
return exp_value; return exp_value;
} }
int Tree::MaxDepth() { void Tree::RecomputeMaxDepth() {
if (leaf_depth_.size() == 0) RecomputeLeafDepths(); if (num_leaves_ == 1) {
if (num_leaves_ == 1) return 0; max_depth_ = 0;
int max_depth = 0; } else {
for (int i = 0; i < num_leaves(); ++i) { if (leaf_depth_.size() == 0) {
if (max_depth < leaf_depth_[i]) max_depth = leaf_depth_[i]; RecomputeLeafDepths(0, 0);
}
max_depth_ = leaf_depth_[0];
for (int i = 1; i < num_leaves(); ++i) {
if (max_depth_ < leaf_depth_[i]) max_depth_ = leaf_depth_[i];
}
} }
return max_depth;
} }
} // namespace LightGBM } // namespace LightGBM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment