"tests/vscode:/vscode.git/clone" did not exist on "d677d6c64754d809cebbc841e745a0db320252b4"
Commit 36732f23 authored by Scott Lundberg, committed by Guolin Ke

Explain individual predictions using SHAP value feature attributions (#825)

* Explain individual predictions using SHAP value feature attributions

* Address code review
parent 3d6c4f35
......@@ -186,9 +186,12 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
* only used in prediction task
* Set to `true` will only predict the raw scores.
* Set to `false` will predict the transformed score
* `predict_leaf_index `, default=`false`, type=bool, alias=`leaf_index `,`is_predict_leaf_index `
* `predict_leaf_index`, default=`false`, type=bool, alias=`leaf_index`,`is_predict_leaf_index`
* only used in prediction task
* Set to `true` to predict with leaf index of all trees
* `predict_contrib`, default=`false`, type=bool, alias=`contrib`,`is_predict_contrib`
* only used in prediction task
* Set to `true` to estimate [SHAP values](https://arxiv.org/abs/1706.06060), which represent how each feature contributed to each prediction. Produces `number of features + 1` values per record, where the last value is the expected value of the model output over the training data (see the sketch after this parameter list).
* `bin_construct_sample_cnt`, default=`200000`, type=int
* Number of data that sampled to construct histogram bins.
* Will give better training result when set this larger. But will increase data loading time.
......
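A minimal, self-contained sketch (not part of this commit; all numbers are hypothetical) of the output layout described above: one SHAP value per feature followed by the expected value, whose sum reconstructs the model's raw score for that record.

```cpp
#include <cstdio>
#include <numeric>
#include <vector>

int main() {
  const int num_features = 3;
  // Hypothetical contributions for a single record:
  // output[0..num_features-1] are per-feature SHAP values,
  // output[num_features] is the expected value of the model output over the training data.
  std::vector<double> output = {0.12, -0.34, 0.05, 1.50};
  // Additivity property of SHAP values: the entries sum to the raw prediction.
  const double raw_score = std::accumulate(output.begin(), output.end(), 0.0);
  std::printf("reconstructed raw score: %f\n", raw_score);
  return 0;
}
```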
......@@ -109,7 +109,7 @@ public:
*/
virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0;
virtual int NumPredictOneRow(int num_iteration, int is_pred_leaf) const = 0;
virtual int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0;
/*!
* \brief Prediction for one record, without sigmoid transform
......@@ -137,6 +137,15 @@ public:
virtual void PredictLeafIndex(
const double* features, double* output) const = 0;
/*!
* \brief Feature contributions for the model's prediction of one record
* \param features Feature values of this record
* \param output Feature contribution results for this record
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
*/
virtual void PredictContrib(const double* features, double* output,
const PredictionEarlyStopInstance* early_stop) const = 0;
/*!
* \brief Dump model to json format string
* \param num_iteration Number of iterations that want to dump, -1 means dump all
......
......@@ -27,6 +27,7 @@ typedef void* BoosterHandle;
#define C_API_PREDICT_NORMAL (0)
#define C_API_PREDICT_RAW_SCORE (1)
#define C_API_PREDICT_LEAF_INDEX (2)
#define C_API_PREDICT_CONTRIB (3)
/*!
* \brief get string message of the last error
......
......@@ -117,6 +117,7 @@ public:
bool enable_load_from_binary_file = true;
int bin_construct_sample_cnt = 200000;
bool is_predict_leaf_index = false;
bool is_predict_contrib = false;
bool is_predict_raw_score = false;
int min_data_in_leaf = 20;
int min_data_in_bin = 5;
......@@ -420,6 +421,8 @@ struct ParameterAlias {
{ "predict_leaf_index", "is_predict_leaf_index" },
{ "raw_score", "is_predict_raw_score" },
{ "leaf_index", "is_predict_leaf_index" },
{ "contrib", "is_predict_contrib" },
{ "predict_contrib", "is_predict_contrib" },
{ "min_split_gain", "min_gain_to_split" },
{ "topk", "top_k" },
{ "reg_alpha", "lambda_l1" },
......@@ -458,7 +461,7 @@ struct ParameterAlias {
"snapshot_freq", "verbosity", "sparse_threshold", "enable_load_from_binary_file",
"max_conflict_rate", "poisson_max_delta_step", "gaussian_eta",
"histogram_pool_size", "output_freq", "is_provide_training_metric", "machine_list_filename", "zero_as_missing",
"init_score_file", "valid_init_score_file"
"init_score_file", "valid_init_score_file", "is_predict_contrib"
});
std::unordered_map<std::string, std::string> tmp_map;
for (const auto& pair : *params) {
......
......@@ -91,6 +91,34 @@ public:
inline double Predict(const double* feature_values) const;
inline int PredictLeafIndex(const double* feature_values) const;
inline void PredictContrib(const double* feature_values, int num_features, double* output) const;
inline double ExpectedValue(int node = 0) const;
inline int MaxDepth() const;
/*!
* \brief Used by TreeSHAP for data we keep about our decision path
*/
struct PathElement {
int feature_index;
double zero_fraction;
double one_fraction;
// note that pweight is included for convenience and is not tied with the other attributes;
// the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them
double pweight;
PathElement() {}
PathElement(int i, double z, double o, double w) : feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {}
};
/*! \brief Polynomial time algorithm for SHAP values (https://arxiv.org/abs/1706.06060) */
inline void TreeSHAP(const double *feature_values, double *phi,
int node, int unique_depth,
PathElement *parent_unique_path, double parent_zero_fraction,
double parent_one_fraction, int parent_feature_index) const;
/*! \brief Get number of leaves */
inline int num_leaves() const { return num_leaves_; }
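For reference, a standalone brute-force sketch (not LightGBM code; the set function and its values are purely illustrative) of the exponential-time Shapley value definition. TreeSHAP computes the same quantity in polynomial time for tree models; roughly speaking, the `pweight` bookkeeping above accumulates these permutation weights incrementally along the decision path.

```cpp
#include <bitset>
#include <cmath>
#include <cstdio>
#include <functional>
#include <vector>

// Brute-force Shapley values for an arbitrary set function f over M players:
//   phi_i = sum over subsets S excluding i of |S|!*(M-|S|-1)!/M! * (f(S with i added) - f(S))
std::vector<double> BruteForceShapley(int M, const std::function<double(unsigned)>& f) {
  auto fact = [](int n) { return std::tgamma(n + 1.0); };
  std::vector<double> phi(M, 0.0);
  for (int i = 0; i < M; ++i) {
    for (unsigned S = 0; S < (1u << M); ++S) {
      if (S & (1u << i)) continue;  // only subsets S that exclude player i
      const int s = static_cast<int>(std::bitset<32>(S).count());
      const double w = fact(s) * fact(M - s - 1) / fact(M);  // permutation weight
      phi[i] += w * (f(S | (1u << i)) - f(S));
    }
  }
  return phi;
}

int main() {
  // Toy additive set function over 3 "features" (illustrative numbers only).
  auto f = [](unsigned S) {
    return 1.0 * (S & 1u) + 2.0 * ((S >> 1) & 1u) + 0.5 * ((S >> 2) & 1u);
  };
  for (double p : BruteForceShapley(3, f)) std::printf("%f\n", p);  // prints 1.0, 2.0, 0.5
  return 0;
}
```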
......@@ -102,6 +130,9 @@ public:
inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }
/*! \brief Get the number of data points that fall at or below this node*/
inline int data_count(int node = 0) const { return node >= 0 ? internal_count_[node] : leaf_count_[~node]; }
/*!
* \brief Shrinkage for the tree's output
* shrinkage rate (a.k.a. learning rate) is used to tune the training process
......@@ -230,6 +261,16 @@ private:
/*! \brief Serialize one node to if-else statement*/
inline std::string NodeToIfElse(int index, bool is_predict_leaf_index);
/*! \brief Extend our decision path with a fraction of one and zero extensions for TreeSHAP*/
inline static void ExtendPath(PathElement *unique_path, int unique_depth,
double zero_fraction, double one_fraction, int feature_index);
/*! \brief Undo a previous extension of the decision path for TreeSHAP*/
inline static void UnwindPath(PathElement *unique_path, int unique_depth, int path_index);
/*! \brief Determine what the total permutation weight would be if we unwound a previous extension in the decision path*/
inline static double UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index);
/*! \brief Number of max leaves*/
int max_leaves_;
/*! \brief Number of current leaves*/
......@@ -286,6 +327,145 @@ inline int Tree::PredictLeafIndex(const double* feature_values) const {
}
}
inline void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
double zero_fraction, double one_fraction, int feature_index) {
unique_path[unique_depth].feature_index = feature_index;
unique_path[unique_depth].zero_fraction = zero_fraction;
unique_path[unique_depth].one_fraction = one_fraction;
unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
for (int i = unique_depth-1; i >= 0; i--) {
unique_path[i+1].pweight += one_fraction*unique_path[i].pweight*(i+1)
/ static_cast<double>(unique_depth+1);
unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth-i)
/ static_cast<double>(unique_depth+1);
}
}
inline void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
const double one_fraction = unique_path[path_index].one_fraction;
const double zero_fraction = unique_path[path_index].zero_fraction;
double next_one_portion = unique_path[unique_depth].pweight;
for (int i = unique_depth-1; i >= 0; --i) {
if (one_fraction != 0) {
const double tmp = unique_path[i].pweight;
unique_path[i].pweight = next_one_portion*(unique_depth+1)
/ static_cast<double>((i+1)*one_fraction);
next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth-i)
/ static_cast<double>(unique_depth+1);
} else {
unique_path[i].pweight = (unique_path[i].pweight*(unique_depth+1))
/ static_cast<double>(zero_fraction*(unique_depth-i));
}
}
for (int i = path_index; i < unique_depth; ++i) {
unique_path[i].feature_index = unique_path[i+1].feature_index;
unique_path[i].zero_fraction = unique_path[i+1].zero_fraction;
unique_path[i].one_fraction = unique_path[i+1].one_fraction;
}
}
inline double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
const double one_fraction = unique_path[path_index].one_fraction;
const double zero_fraction = unique_path[path_index].zero_fraction;
double next_one_portion = unique_path[unique_depth].pweight;
double total = 0;
for (int i = unique_depth-1; i >= 0; --i) {
if (one_fraction != 0) {
const double tmp = next_one_portion*(unique_depth+1)
/ static_cast<double>((i+1)*one_fraction);
total += tmp;
next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth-i)
/ static_cast<double>(unique_depth+1));
} else {
total += (unique_path[i].pweight/zero_fraction)/((unique_depth-i)
/ static_cast<double>(unique_depth+1));
}
}
return total;
}
// recursive computation of SHAP values for a decision tree
inline void Tree::TreeSHAP(const double *feature_values, double *phi,
int node, int unique_depth,
PathElement *parent_unique_path, double parent_zero_fraction,
double parent_one_fraction, int parent_feature_index) const {
// extend the unique path
PathElement *unique_path = parent_unique_path + unique_depth;
if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path+unique_depth, unique_path);
ExtendPath(unique_path, unique_depth, parent_zero_fraction,
parent_one_fraction, parent_feature_index);
const int split_index = split_feature_[node];
// leaf node
if (node < 0) {
for (int i = 1; i <= unique_depth; ++i) {
const double w = UnwoundPathSum(unique_path, unique_depth, i);
const PathElement &el = unique_path[i];
phi[el.feature_index] += w*(el.one_fraction-el.zero_fraction)*leaf_value_[~node];
}
// internal node
} else {
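// the "hot" child is the branch this record's feature value actually follows; the "cold" child is the other branch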
const int hot_index = Decision(feature_values[split_index], node);
const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
const double w = data_count(node);
const double hot_zero_fraction = data_count(hot_index)/w;
const double cold_zero_fraction = data_count(cold_index)/w;
double incoming_zero_fraction = 1;
double incoming_one_fraction = 1;
// see if we have already split on this feature,
// if so we undo that split so we can redo it for this node
int path_index = 0;
for (; path_index <= unique_depth; ++path_index) {
if (unique_path[path_index].feature_index == split_index) break;
}
if (path_index != unique_depth+1) {
incoming_zero_fraction = unique_path[path_index].zero_fraction;
incoming_one_fraction = unique_path[path_index].one_fraction;
UnwindPath(unique_path, unique_depth, path_index);
unique_depth -= 1;
}
TreeSHAP(feature_values, phi, hot_index, unique_depth+1, unique_path,
hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_index);
TreeSHAP(feature_values, phi, cold_index, unique_depth+1, unique_path,
cold_zero_fraction*incoming_zero_fraction, 0, split_index);
}
}
inline void Tree::PredictContrib(const double* feature_values, int num_features, double *output) const {
output[num_features] += ExpectedValue();
// Run the recursion with preallocated space for the unique path data
const int max_path_len = MaxDepth()+1;
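// each recursion level keeps its own copy of the path (level d holds up to d+1 elements),
// so the total preallocation is the triangular number max_path_len*(max_path_len+1)/2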
PathElement *unique_path_data = new PathElement[(max_path_len*(max_path_len+1))/2];
TreeSHAP(feature_values, output, 0, 0, unique_path_data, 1, 1, -1);
delete[] unique_path_data;
}
inline double Tree::ExpectedValue(int node) const {
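// recursively compute the training-data-weighted average of the leaf outputs,
// i.e. the expected model output over the training data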
if (node >= 0) {
const int l = left_child_[node];
const int r = right_child_[node];
return (data_count(l)*ExpectedValue(l) + data_count(r)*ExpectedValue(r))/data_count(node);
} else {
return LeafOutput(~node);
}
}
inline int Tree::MaxDepth() const {
int max_depth = 0;
for (int i = 0; i < num_leaves(); ++i) {
if (max_depth < leaf_depth_[i]) max_depth = leaf_depth_[i];
}
return max_depth;
}
inline int Tree::GetLeaf(const double* feature_values) const {
int node = 0;
if (has_categorical_) {
......
......@@ -111,7 +111,7 @@ void Application::LoadData() {
PredictionEarlyStopInstance pred_early_stop = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
// need to continue training
if (boosting_->NumberOfTotalModel() > 0) {
predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, -1, -1));
predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, false, -1, -1));
predict_fun = predictor->GetPredictFunction();
}
......@@ -236,8 +236,9 @@ void Application::Train() {
void Application::Predict() {
// create predictor
Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
config_.io_config.is_predict_leaf_index, config_.io_config.pred_early_stop,
config_.io_config.pred_early_stop_freq, config_.io_config.pred_early_stop_margin);
config_.io_config.is_predict_leaf_index, config_.io_config.is_predict_contrib,
config_.io_config.pred_early_stop, config_.io_config.pred_early_stop_freq,
config_.io_config.pred_early_stop_margin);
predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header);
Log::Info("Finished prediction");
......
......@@ -28,10 +28,11 @@ public:
* \param boosting Input boosting model
* \param num_iteration Number of boosting round
* \param is_raw_score True if need to predict result with raw score
* \param is_predict_leaf_index True if output leaf index instead of prediction score
* \param is_predict_leaf_index True to output leaf index instead of prediction score
* \param is_predict_contrib True to output feature contributions instead of prediction score
*/
Predictor(Boosting* boosting, int num_iteration,
bool is_raw_score, bool is_predict_leaf_index,
bool is_raw_score, bool is_predict_leaf_index, bool is_predict_contrib,
bool early_stop, int early_stop_freq, double early_stop_margin) {
early_stop_ = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
......@@ -53,7 +54,7 @@ public:
}
boosting->InitPredict(num_iteration);
boosting_ = boosting;
num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index);
num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index, is_predict_contrib);
num_feature_ = boosting_->MaxFeatureIdx() + 1;
predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(num_feature_, 0.0f));
......@@ -66,6 +67,15 @@ public:
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
};
} else if (is_predict_contrib) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
int tid = omp_get_thread_num();
CopyToPredictBuffer(predict_buf_[tid].data(), features);
// get feature contributions (SHAP values)
boosting_->PredictContrib(predict_buf_[tid].data(), output, &early_stop_);
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
};
} else {
if (is_raw_score) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
......
......@@ -739,6 +739,27 @@ const double* GBDT::GetTrainingScore(int64_t* out_len) {
return train_score_updater_->score();
}
void GBDT::PredictContrib(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
int early_stop_round_counter = 0;
// set zero
const int num_features = max_feature_idx_+1;
std::memset(output, 0, sizeof(double) * num_tree_per_iteration_ * (num_features+1));
for (int i = 0; i < num_iteration_for_pred_; ++i) {
// predict all the trees for one iteration
for (int k = 0; k < num_tree_per_iteration_; ++k) {
models_[i * num_tree_per_iteration_ + k]->PredictContrib(features, num_features, output + k*(num_features+1));
}
// check early stopping
++early_stop_round_counter;
if (early_stop->round_period == early_stop_round_counter) {
if (early_stop->callback_function(output, num_tree_per_iteration_)) {
return;
}
early_stop_round_counter = 0;
}
}
}
void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
......
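A hedged sketch (illustrative names and sizes only, not part of the LightGBM sources) of how the flattened buffer filled by `GBDT::PredictContrib` is laid out, one block per tree-per-iteration (i.e. per class for multiclass models):

```cpp
#include <cstdio>
#include <vector>

int main() {
  // Mirrors the indexing above: for block k and num_features features,
  //   output[k*(num_features+1) + j]             -> SHAP value of feature j for block k
  //   output[k*(num_features+1) + num_features]  -> expected value (baseline) for block k
  const int num_blocks = 2, num_features = 3;
  std::vector<double> output(num_blocks * (num_features + 1), 0.0);  // zeroed, as in the memset above
  for (int k = 0; k < num_blocks; ++k) {
    const double* row = output.data() + k * (num_features + 1);
    std::printf("block %d baseline: %f\n", k, row[num_features]);
  }
  return 0;
}
```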
......@@ -133,7 +133,7 @@ public:
*/
void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) override;
inline int NumPredictOneRow(int num_iteration, int is_pred_leaf) const override {
inline int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const override {
int num_preb_in_one_row = num_class_;
if (is_pred_leaf) {
int max_iteration = GetCurrentIteration();
......@@ -142,6 +142,8 @@ public:
} else {
num_preb_in_one_row *= max_iteration;
}
} else if (is_pred_contrib) {
num_preb_in_one_row = max_feature_idx_ + 2; // +1 for 0-based indexing, +1 for baseline
}
return num_preb_in_one_row;
}
......@@ -154,6 +156,9 @@ public:
void PredictLeafIndex(const double* features, double* output) const override;
void PredictContrib(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop) const override;
/*!
* \brief Dump model to json format string
* \param num_iteration Number of iterations that want to dump, -1 means dump all
......
......@@ -178,17 +178,20 @@ public:
std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false;
bool is_raw_score = false;
bool is_predict_contrib = false;
if (predict_type == C_API_PREDICT_LEAF_INDEX) {
is_predict_leaf = true;
} else if (predict_type == C_API_PREDICT_RAW_SCORE) {
is_raw_score = true;
} else if (predict_type == C_API_PREDICT_CONTRIB) {
is_predict_contrib = true;
} else {
is_raw_score = false;
}
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, is_predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf);
int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, is_predict_contrib);
auto pred_fun = predictor.GetPredictFunction();
OMP_INIT_EX();
#pragma omp parallel for schedule(static)
......@@ -209,14 +212,17 @@ public:
std::lock_guard<std::mutex> lock(mutex_);
bool is_predict_leaf = false;
bool is_raw_score = false;
bool is_predict_contrib = false;
if (predict_type == C_API_PREDICT_LEAF_INDEX) {
is_predict_leaf = true;
} else if (predict_type == C_API_PREDICT_RAW_SCORE) {
is_raw_score = true;
} else if (predict_type == C_API_PREDICT_CONTRIB) {
is_predict_contrib = true;
} else {
is_raw_score = false;
}
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, is_predict_contrib,
config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
bool bool_data_has_header = data_has_header > 0 ? true : false;
predictor.Predict(data_filename, result_filename, bool_data_has_header);
......@@ -998,7 +1004,7 @@ int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
*out_len = static_cast<int64_t>(num_row * ref_booster->GetBoosting()->NumPredictOneRow(
num_iteration, predict_type == C_API_PREDICT_LEAF_INDEX));
num_iteration, predict_type == C_API_PREDICT_LEAF_INDEX, predict_type == C_API_PREDICT_CONTRIB));
API_END();
}
......
......@@ -257,6 +257,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
GetBool(params, "is_predict_raw_score", &is_predict_raw_score);
GetBool(params, "is_predict_leaf_index", &is_predict_leaf_index);
GetBool(params, "is_predict_contrib", &is_predict_contrib);
GetInt(params, "snapshot_freq", &snapshot_freq);
GetString(params, "output_model", &output_model);
GetString(params, "input_model", &input_model);
......