Commit 36732f23 authored by Scott Lundberg, committed by Guolin Ke

Explain individual predictions using SHAP value feature attributions (#825)

* Explain individual predictions using SHAP value feature attributions

* Address code review
parent 3d6c4f35
...@@ -37,13 +37,13 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
* `huber`, [Huber loss](https://en.wikipedia.org/wiki/Huber_loss "Huber loss - Wikipedia")
* `fair`, [Fair loss](https://www.kaggle.com/c/allstate-claims-severity/discussion/24520)
* `poisson`, [Poisson regression](https://en.wikipedia.org/wiki/Poisson_regression "Poisson regression")
* `binary`, binary classification application
* `lambdarank`, [lambdarank](https://pdfs.semanticscholar.org/fc9a/e09f9ced555558fdf1e997c0a5411fb51f15.pdf) application
  * The label should be of `int` type in lambdarank tasks, and larger numbers represent higher relevance (e.g. 0:bad, 1:fair, 2:good, 3:perfect).
  * `label_gain` can be used to set the gain (weight) of each `int` label.
* `multiclass`, multi-class classification application, should set `num_class` as well
* `boosting`, default=`gbdt`, type=enum, options=`gbdt`,`rf`,`dart`,`goss`, alias=`boost`,`boosting_type`
  * `gbdt`, traditional Gradient Boosting Decision Tree
  * `rf`, Random Forest
  * `dart`, [Dropouts meet Multiple Additive Regression Trees](https://arxiv.org/abs/1505.01866)
  * `goss`, Gradient-based One-Side Sampling
...@@ -67,7 +67,7 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
  * `data`, data parallel tree learner
  * Refer to [Parallel Learning Guide](./Parallel-Learning-Guide.md) to get more details.
* `num_threads`, default=OpenMP_default, type=int, alias=`num_thread`,`nthread`
  * Number of threads for LightGBM.
  * For the best speed, set this to the number of **real CPU cores**, not the number of threads (most CPUs use [hyper-threading](https://en.wikipedia.org/wiki/Hyper-threading) to generate 2 threads per CPU core).
  * Do not set it too large if your dataset is small (for instance, do not use 64 threads for a dataset with 10,000 rows).
  * Be aware that a task manager or any similar CPU monitoring tool might report cores not being fully utilized. This is normal.
...@@ -79,8 +79,8 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
## Learning control parameters
* `max_depth`, default=`-1`, type=int
  * Limit the max depth of the tree model. This is used to deal with over-fitting when #data is small. The tree still grows leaf-wise.
  * `< 0` means no limit
* `min_data_in_leaf`, default=`20`, type=int, alias=`min_data_per_leaf`, `min_data`
  * Minimal number of data points in one leaf. Can be used to deal with over-fitting.
* `min_sum_hessian_in_leaf`, default=`1e-3`, type=double, alias=`min_sum_hessian_per_leaf`, `min_sum_hessian`, `min_hessian`
...@@ -104,11 +104,11 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
* `early_stopping_round`, default=`0`, type=int, alias=`early_stopping_rounds`,`early_stopping`
  * Will stop training if one metric of one validation data doesn't improve in the last `early_stopping_round` rounds.
* `lambda_l1`, default=`0`, type=double
  * L1 regularization
* `lambda_l2`, default=`0`, type=double
  * L2 regularization
* `min_gain_to_split`, default=`0`, type=double
  * The minimal gain to perform a split
* `drop_rate`, default=`0.1`, type=double
  * only used in `dart`
* `skip_drop`, default=`0.5`, type=double
...@@ -186,22 +186,25 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
  * only used in prediction task
  * Set to `true` will only predict the raw scores.
  * Set to `false` will output the transformed score.
* `predict_leaf_index`, default=`false`, type=bool, alias=`leaf_index`,`is_predict_leaf_index`
  * only used in prediction task
  * Set to `true` to predict with the leaf index of all trees
* `predict_contrib`, default=`false`, type=bool, alias=`contrib`,`is_predict_contrib`
* only used in prediction task
* Set to `true` to estimate [SHAP values](https://arxiv.org/abs/1706.06060), which represent how each feature contributed to each prediction. Produces number of features + 1 values where the last value is the expected value of the model output over the training data.
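As a sketch only, a minimal prediction config enabling this option might look like the following (`test.txt` and `LightGBM_model.txt` are placeholder file names):

```
task = predict
data = test.txt
input_model = LightGBM_model.txt
predict_contrib = true
```

With this config, each output row holds one contribution per feature plus the expected value in the last column.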
* `bin_construct_sample_cnt`, default=`200000`, type=int
  * Number of data points sampled to construct histogram bins.
  * Setting this larger will give a better training result, but will increase data loading time.
  * Set this to a larger value if the data is very sparse.
* `num_iteration_predict`, default=`-1`, type=int
  * only used in prediction task, specifies how many trained iterations will be used in prediction.
  * `<= 0` means no limit
* `pred_early_stop`, default=`false`, type=bool
  * Set to `true` to use early stopping to speed up the prediction. May affect the accuracy.
* `pred_early_stop_freq`, default=`10`, type=int
  * The frequency of checking early-stopping prediction.
* `pred_early_stop_margin`, default=`10.0`, type=double
  * The threshold of margin in early-stopping prediction.
* `use_missing`, default=`true`, type=bool
  * Set to `false` will disable the special handling of missing values.
* `zero_as_missing`, default=`false`, type=bool
...@@ -265,14 +268,14 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
## Network parameters
The following parameters are used for parallel learning, and are only used in the base (socket) version.
* `num_machines`, default=`1`, type=int, alias=`num_machine`
  * Used for parallel learning; the number of machines in the parallel learning application.
  * Needs to be set in both the socket and MPI versions.
* `local_listen_port`, default=`12400`, type=int, alias=`local_port`
  * TCP listen port for local machines.
  * You should allow this port in the firewall settings before training.
* `time_out`, default=`120`, type=int
  * Socket time-out in minutes.
* `machine_list_file`, default=`""`, type=string
...@@ -285,8 +288,8 @@ Following parameters are used for parallel learning, and only used for base(sock
  * OpenCL platform ID. Usually each GPU vendor exposes one OpenCL platform.
  * Default value is -1, meaning the system-wide default platform is used.
* `gpu_device_id`, default=`-1`, type=int
  * OpenCL device ID in the specified platform. Each GPU in the selected platform has a unique device ID.
  * Default value is -1, meaning the default device in the selected platform is used.
* `gpu_use_dp`, default=`false`, type=bool
  * Set to `true` to use double precision math on GPU (the default is single precision).
...@@ -313,7 +316,7 @@ LightGBM support continued train with initial score. It uses an additional file
...
```
It means the initial score of the first data point is `0.5`, the second is `-0.1`, and so on. The initial score file corresponds to the data file line by line, with one score per line. If the data file is named "train.txt", the initial score file should be named "train.txt.init" and placed in the same folder as the data file. LightGBM will automatically load the initial score file if it exists.
### Weight data
...
...@@ -109,7 +109,7 @@ public:
*/
virtual void GetPredictAt(int data_idx, double* result, int64_t* out_len) = 0;
virtual int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const = 0;
/*!
* \brief Prediction for one record, not sigmoid transform
...@@ -128,7 +128,7 @@ public:
*/
virtual void Predict(const double* features, double* output,
const PredictionEarlyStopInstance* early_stop) const = 0;
/*!
* \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record
...@@ -137,6 +137,15 @@ public:
virtual void PredictLeafIndex(
const double* features, double* output) const = 0;
/*!
* \brief Feature contributions for the model's prediction of one record
* \param feature_values Feature value on this record
* \param output Prediction result for this record
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
*/
virtual void PredictContrib(const double* features, double* output,
const PredictionEarlyStopInstance* early_stop) const = 0;
/*!
* \brief Dump model to json format string
* \param num_iteration Number of iterations to dump, -1 means dump all
...@@ -205,7 +214,7 @@ public:
* \return Number of weak sub-models
*/
virtual int NumberOfTotalModel() const = 0;
/*!
* \brief Get number of trees per iteration
* \return Number of trees per iteration
...@@ -226,7 +235,7 @@ public:
* \param num_iteration number of used iterations
*/
virtual void InitPredict(int num_iteration) = 0;
/*!
* \brief Name of submodel
*/
...
...@@ -27,6 +27,7 @@ typedef void* BoosterHandle;
#define C_API_PREDICT_NORMAL (0)
#define C_API_PREDICT_RAW_SCORE (1)
#define C_API_PREDICT_LEAF_INDEX (2)
#define C_API_PREDICT_CONTRIB (3)
/*!
* \brief get string message of the last error
...@@ -54,7 +55,7 @@ LIGHTGBM_C_EXPORT int LGBM_DatasetCreateFromFile(const char* filename,
/*!
* \brief create an empty dataset by sampling data.
* \param sample_data sampled data, grouped by column.
* \param sample_indices indices of sampled data.
* \param ncol number of columns
* \param num_per_col Size of each sampling column
* \param num_sample_row Number of sampled rows
...
...@@ -117,6 +117,7 @@ public:
bool enable_load_from_binary_file = true;
int bin_construct_sample_cnt = 200000;
bool is_predict_leaf_index = false;
bool is_predict_contrib = false;
bool is_predict_raw_score = false;
int min_data_in_leaf = 20;
int min_data_in_bin = 5;
...@@ -127,7 +128,7 @@ public:
* And add a prefix "name:" when using the column name */
std::string label_column = "";
/*! \brief Index or column name of weight, < 0 means not used
* And add a prefix "name:" when using the column name
* Note: when using an index, it doesn't count the label index */
std::string weight_column = "";
/*! \brief Index or column name of group/query id, < 0 means not used
...@@ -417,16 +418,18 @@ struct ParameterAlias {
{ "cat_column", "categorical_column" },
{ "cat_feature", "categorical_column" },
{ "predict_raw_score", "is_predict_raw_score" },
{ "predict_leaf_index", "is_predict_leaf_index" },
{ "raw_score", "is_predict_raw_score" },
{ "leaf_index", "is_predict_leaf_index" },
{ "contrib", "is_predict_contrib" },
{ "predict_contrib", "is_predict_contrib" },
{ "min_split_gain", "min_gain_to_split" },
{ "topk", "top_k" },
{ "reg_alpha", "lambda_l1" },
{ "reg_lambda", "lambda_l2" },
{ "num_classes", "num_class" },
{ "unbalanced_sets", "is_unbalance" },
{ "bagging_fraction_seed", "bagging_seed" },
{ "num_boost_round", "num_iterations" }
});
const std::unordered_set<std::string> parameter_set({
...@@ -453,12 +456,12 @@ struct ParameterAlias {
"boost_from_average", "max_position", "label_gain",
"metric", "metric_freq", "time_out",
"gpu_platform_id", "gpu_device_id", "gpu_use_dp",
"convert_model", "convert_model_language",
"feature_fraction_seed", "enable_bundle", "data_filename", "valid_data_filenames",
"snapshot_freq", "verbosity", "sparse_threshold", "enable_load_from_binary_file",
"max_conflict_rate", "poisson_max_delta_step", "gaussian_eta",
"histogram_pool_size", "output_freq", "is_provide_training_metric", "machine_list_filename", "zero_as_missing",
"init_score_file", "valid_init_score_file", "is_predict_contrib"
});
std::unordered_map<std::string, std::string> tmp_map;
for (const auto& pair : *params) {
...
...@@ -91,6 +91,34 @@ public:
inline double Predict(const double* feature_values) const;
inline int PredictLeafIndex(const double* feature_values) const;
inline void PredictContrib(const double* feature_values, int num_features, double* output) const;
inline double ExpectedValue(int node = 0) const;
inline int MaxDepth() const;
/*!
* \brief Used by TreeSHAP for data we keep about our decision path
*/
struct PathElement {
int feature_index;
double zero_fraction;
double one_fraction;
// note that pweight is included for convenience and is not tied to the other attributes;
// the pweight of the i'th path element is the permutation weight of paths with i-1 ones in them
double pweight;
PathElement() {}
PathElement(int i, double z, double o, double w) : feature_index(i), zero_fraction(z), one_fraction(o), pweight(w) {}
};
/*! \brief Polynomial time algorithm for SHAP values (https://arxiv.org/abs/1706.06060) */
inline void TreeSHAP(const double *feature_values, double *phi,
int node, int unique_depth,
PathElement *parent_unique_path, double parent_zero_fraction,
double parent_one_fraction, int parent_feature_index) const;
/*! \brief Get Number of leaves*/
inline int num_leaves() const { return num_leaves_; }
...@@ -102,6 +130,9 @@ public:
inline double split_gain(int split_idx) const { return split_gain_[split_idx]; }
/*! \brief Get the number of data points that fall at or below this node*/
inline int data_count(int node = 0) const { return node >= 0 ? internal_count_[node] : leaf_count_[~node]; }
/*!
* \brief Shrinkage for the tree's output
* shrinkage rate (a.k.a. learning rate) is used to tune the training process
...@@ -111,7 +142,7 @@ public:
#pragma omp parallel for schedule(static, 512) if (num_leaves_ >= 1024)
for (int i = 0; i < num_leaves_; ++i) {
leaf_value_[i] *= rate;
if (leaf_value_[i] > kMaxTreeOutput) { leaf_value_[i] = kMaxTreeOutput; }
else if (leaf_value_[i] < -kMaxTreeOutput) { leaf_value_[i] = -kMaxTreeOutput; }
}
shrinkage_ *= rate;
...@@ -230,6 +261,16 @@ private:
/*! \brief Serialize one node to if-else statement*/
inline std::string NodeToIfElse(int index, bool is_predict_leaf_index);
/*! \brief Extend our decision path with a fraction of one and zero extensions for TreeSHAP*/
inline static void ExtendPath(PathElement *unique_path, int unique_depth,
double zero_fraction, double one_fraction, int feature_index);
/*! \brief Undo a previous extension of the decision path for TreeSHAP*/
inline static void UnwindPath(PathElement *unique_path, int unique_depth, int path_index);
/*! \brief Determine what the total permutation weight would be if we unwound a previous extension in the decision path*/
inline static double UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index);
/*! \brief Number of max leaves*/
int max_leaves_;
/*! \brief Number of current leaves*/
...@@ -286,6 +327,145 @@ inline int Tree::PredictLeafIndex(const double* feature_values) const {
}
}
inline void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
double zero_fraction, double one_fraction, int feature_index) {
unique_path[unique_depth].feature_index = feature_index;
unique_path[unique_depth].zero_fraction = zero_fraction;
unique_path[unique_depth].one_fraction = one_fraction;
unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
for (int i = unique_depth-1; i >= 0; i--) {
unique_path[i+1].pweight += one_fraction*unique_path[i].pweight*(i+1)
/ static_cast<double>(unique_depth+1);
unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth-i)
/ static_cast<double>(unique_depth+1);
}
}
inline void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
const double one_fraction = unique_path[path_index].one_fraction;
const double zero_fraction = unique_path[path_index].zero_fraction;
double next_one_portion = unique_path[unique_depth].pweight;
for (int i = unique_depth-1; i >= 0; --i) {
if (one_fraction != 0) {
const double tmp = unique_path[i].pweight;
unique_path[i].pweight = next_one_portion*(unique_depth+1)
/ static_cast<double>((i+1)*one_fraction);
next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth-i)
/ static_cast<double>(unique_depth+1);
} else {
unique_path[i].pweight = (unique_path[i].pweight*(unique_depth+1))
/ static_cast<double>(zero_fraction*(unique_depth-i));
}
}
for (int i = path_index; i < unique_depth; ++i) {
unique_path[i].feature_index = unique_path[i+1].feature_index;
unique_path[i].zero_fraction = unique_path[i+1].zero_fraction;
unique_path[i].one_fraction = unique_path[i+1].one_fraction;
}
}
inline double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
const double one_fraction = unique_path[path_index].one_fraction;
const double zero_fraction = unique_path[path_index].zero_fraction;
double next_one_portion = unique_path[unique_depth].pweight;
double total = 0;
for (int i = unique_depth-1; i >= 0; --i) {
if (one_fraction != 0) {
const double tmp = next_one_portion*(unique_depth+1)
/ static_cast<double>((i+1)*one_fraction);
total += tmp;
next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth-i)
/ static_cast<double>(unique_depth+1));
} else {
total += (unique_path[i].pweight/zero_fraction)/((unique_depth-i)
/ static_cast<double>(unique_depth+1));
}
}
return total;
}
// recursive computation of SHAP values for a decision tree
inline void Tree::TreeSHAP(const double *feature_values, double *phi,
int node, int unique_depth,
PathElement *parent_unique_path, double parent_zero_fraction,
double parent_one_fraction, int parent_feature_index) const {
// extend the unique path
PathElement *unique_path = parent_unique_path + unique_depth;
if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path+unique_depth, unique_path);
ExtendPath(unique_path, unique_depth, parent_zero_fraction,
parent_one_fraction, parent_feature_index);
// leaf node
if (node < 0) {
for (int i = 1; i <= unique_depth; ++i) {
const double w = UnwoundPathSum(unique_path, unique_depth, i);
const PathElement &el = unique_path[i];
phi[el.feature_index] += w*(el.one_fraction-el.zero_fraction)*leaf_value_[~node];
}
// internal node
} else {
// split_feature_ is only indexed by internal (non-negative) nodes,
// so look up the split feature after the leaf check
const int split_index = split_feature_[node];
const int hot_index = Decision(feature_values[split_index], node);
const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
const double w = data_count(node);
const double hot_zero_fraction = data_count(hot_index)/w;
const double cold_zero_fraction = data_count(cold_index)/w;
double incoming_zero_fraction = 1;
double incoming_one_fraction = 1;
// see if we have already split on this feature,
// if so we undo that split so we can redo it for this node
int path_index = 0;
for (; path_index <= unique_depth; ++path_index) {
if (unique_path[path_index].feature_index == split_index) break;
}
if (path_index != unique_depth+1) {
incoming_zero_fraction = unique_path[path_index].zero_fraction;
incoming_one_fraction = unique_path[path_index].one_fraction;
UnwindPath(unique_path, unique_depth, path_index);
unique_depth -= 1;
}
TreeSHAP(feature_values, phi, hot_index, unique_depth+1, unique_path,
hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_index);
TreeSHAP(feature_values, phi, cold_index, unique_depth+1, unique_path,
cold_zero_fraction*incoming_zero_fraction, 0, split_index);
}
}
inline void Tree::PredictContrib(const double* feature_values, int num_features, double *output) const {
output[num_features] += ExpectedValue();
// Run the recursion with preallocated space for the unique path data
const int max_path_len = MaxDepth()+1;
std::vector<PathElement> unique_path_data((max_path_len*(max_path_len+1))/2);
TreeSHAP(feature_values, output, 0, 0, unique_path_data.data(), 1, 1, -1);
}
inline double Tree::ExpectedValue(int node) const {
if (node >= 0) {
const int l = left_child_[node];
const int r = right_child_[node];
return (data_count(l)*ExpectedValue(l) + data_count(r)*ExpectedValue(r))/data_count(node);
} else {
return LeafOutput(~node);
}
}
inline int Tree::MaxDepth() const {
int max_depth = 0;
for (int i = 0; i < num_leaves(); ++i) {
if (max_depth < leaf_depth_[i]) max_depth = leaf_depth_[i];
}
return max_depth;
}
inline int Tree::GetLeaf(const double* feature_values) const {
int node = 0;
if (has_categorical_) {
...
...@@ -111,7 +111,7 @@ void Application::LoadData() {
PredictionEarlyStopInstance pred_early_stop = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
// need to continue training
if (boosting_->NumberOfTotalModel() > 0) {
predictor.reset(new Predictor(boosting_.get(), -1, true, false, false, false, -1, -1));
predict_fun = predictor->GetPredictFunction();
}
...@@ -236,8 +236,9 @@ void Application::Train() {
void Application::Predict() {
// create predictor
Predictor predictor(boosting_.get(), config_.io_config.num_iteration_predict, config_.io_config.is_predict_raw_score,
config_.io_config.is_predict_leaf_index, config_.io_config.is_predict_contrib,
config_.io_config.pred_early_stop, config_.io_config.pred_early_stop_freq,
config_.io_config.pred_early_stop_margin);
predictor.Predict(config_.io_config.data_filename.c_str(),
config_.io_config.output_result.c_str(), config_.io_config.has_header);
Log::Info("Finished prediction");
...
@@ -28,10 +28,11 @@ public:
   * \param boosting Input boosting model
   * \param num_iteration Number of boosting round
   * \param is_raw_score True if need to predict result with raw score
-  * \param is_predict_leaf_index True if output leaf index instead of prediction score
+  * \param is_predict_leaf_index True to output leaf index instead of prediction score
+  * \param is_predict_contrib True to output feature contributions instead of prediction score
   */
   Predictor(Boosting* boosting, int num_iteration,
-            bool is_raw_score, bool is_predict_leaf_index,
+            bool is_raw_score, bool is_predict_leaf_index, bool is_predict_contrib,
             bool early_stop, int early_stop_freq, double early_stop_margin) {
     early_stop_ = CreatePredictionEarlyStopInstance("none", LightGBM::PredictionEarlyStopConfig());
@@ -53,7 +54,7 @@ public:
     }
     boosting->InitPredict(num_iteration);
     boosting_ = boosting;
-    num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index);
+    num_pred_one_row_ = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf_index, is_predict_contrib);
     num_feature_ = boosting_->MaxFeatureIdx() + 1;
     predict_buf_ = std::vector<std::vector<double>>(num_threads_, std::vector<double>(num_feature_, 0.0f));
@@ -66,6 +67,15 @@ public:
       ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
     };
+  } else if (is_predict_contrib) {
+    predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
+      int tid = omp_get_thread_num();
+      CopyToPredictBuffer(predict_buf_[tid].data(), features);
+      // get feature contributions (SHAP values)
+      boosting_->PredictContrib(predict_buf_[tid].data(), output, &early_stop_);
+      ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
+    };
   } else {
     if (is_raw_score) {
       predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
...
@@ -739,6 +739,27 @@ const double* GBDT::GetTrainingScore(int64_t* out_len) {
   return train_score_updater_->score();
 }

 void GBDT::PredictContrib(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
   int early_stop_round_counter = 0;
   // set zero
   const int num_features = max_feature_idx_ + 1;
   std::memset(output, 0, sizeof(double) * num_tree_per_iteration_ * (num_features + 1));
   for (int i = 0; i < num_iteration_for_pred_; ++i) {
     // predict all the trees for one iteration
     for (int k = 0; k < num_tree_per_iteration_; ++k) {
       models_[i * num_tree_per_iteration_ + k]->PredictContrib(features, num_features, output + k * (num_features + 1));
     }
     // check early stopping
     ++early_stop_round_counter;
     if (early_stop->round_period == early_stop_round_counter) {
       if (early_stop->callback_function(output, num_tree_per_iteration_)) {
         return;
       }
       early_stop_round_counter = 0;
     }
   }
 }
 void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
   CHECK(data_idx >= 0 && data_idx <= static_cast<int>(valid_score_updater_.size()));
...
@@ -95,7 +95,7 @@ public:
   bool EvalAndCheckEarlyStopping() override;

   bool NeedAccuratePrediction() const override {
     if (objective_function_ == nullptr) {
       return true;
     } else {
@@ -133,7 +133,7 @@ public:
   */
   void GetPredictAt(int data_idx, double* out_result, int64_t* out_len) override;

-  inline int NumPredictOneRow(int num_iteration, int is_pred_leaf) const override {
+  inline int NumPredictOneRow(int num_iteration, bool is_pred_leaf, bool is_pred_contrib) const override {
     int num_preb_in_one_row = num_class_;
     if (is_pred_leaf) {
       int max_iteration = GetCurrentIteration();
@@ -142,6 +142,8 @@ public:
       } else {
         num_preb_in_one_row *= max_iteration;
       }
+    } else if (is_pred_contrib) {
+      num_preb_in_one_row = max_feature_idx_ + 2;  // +1 for 0-based indexing, +1 for baseline
     }
     return num_preb_in_one_row;
   }
@@ -154,6 +156,9 @@ public:
   void PredictLeafIndex(const double* features, double* output) const override;

+  void PredictContrib(const double* features, double* output,
+                      const PredictionEarlyStopInstance* earlyStop) const override;

  /*!
   * \brief Dump model to json format string
   * \param num_iteration Number of iterations that want to dump, -1 means dump all
@@ -299,7 +304,7 @@ protected:
   std::string OutputMetric(int iter);

  /*!
   * \brief Calculate feature importances
   * \param num_used_model Number of models to use for feature importance, -1 means use all
   * \return sorted pairs of (feature_importance, feature_name)
   */
   std::vector<std::pair<size_t, std::string>> FeatureImportance(int num_used_model) const;
...
@@ -178,17 +178,20 @@ public:
     std::lock_guard<std::mutex> lock(mutex_);
     bool is_predict_leaf = false;
     bool is_raw_score = false;
+    bool is_predict_contrib = false;
     if (predict_type == C_API_PREDICT_LEAF_INDEX) {
       is_predict_leaf = true;
     } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
       is_raw_score = true;
+    } else if (predict_type == C_API_PREDICT_CONTRIB) {
+      is_predict_contrib = true;
     } else {
       is_raw_score = false;
     }
-    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
+    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, is_predict_contrib,
                         config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
-    int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf);
+    int64_t num_preb_in_one_row = boosting_->NumPredictOneRow(num_iteration, is_predict_leaf, is_predict_contrib);
     auto pred_fun = predictor.GetPredictFunction();
     OMP_INIT_EX();
     #pragma omp parallel for schedule(static)
@@ -209,14 +212,17 @@ public:
     std::lock_guard<std::mutex> lock(mutex_);
     bool is_predict_leaf = false;
     bool is_raw_score = false;
+    bool is_predict_contrib = false;
     if (predict_type == C_API_PREDICT_LEAF_INDEX) {
       is_predict_leaf = true;
     } else if (predict_type == C_API_PREDICT_RAW_SCORE) {
       is_raw_score = true;
+    } else if (predict_type == C_API_PREDICT_CONTRIB) {
+      is_predict_contrib = true;
     } else {
       is_raw_score = false;
     }
-    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf,
+    Predictor predictor(boosting_.get(), num_iteration, is_raw_score, is_predict_leaf, is_predict_contrib,
                         config.pred_early_stop, config.pred_early_stop_freq, config.pred_early_stop_margin);
     bool bool_data_has_header = data_has_header > 0 ? true : false;
     predictor.Predict(data_filename, result_filename, bool_data_has_header);
@@ -998,7 +1004,7 @@ int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
   API_BEGIN();
   Booster* ref_booster = reinterpret_cast<Booster*>(handle);
   *out_len = static_cast<int64_t>(num_row * ref_booster->GetBoosting()->NumPredictOneRow(
-    num_iteration, predict_type == C_API_PREDICT_LEAF_INDEX));
+    num_iteration, predict_type == C_API_PREDICT_LEAF_INDEX, predict_type == C_API_PREDICT_CONTRIB));
   API_END();
 }
...
@@ -187,7 +187,7 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& params) {
 void OverallConfig::CheckParamConflict() {
   // check if objective_type, metric_type, and num_class match
   bool objective_type_multiclass = (objective_type == std::string("multiclass")
                                     || objective_type == std::string("multiclassova"));
   int num_class_check = boosting_config.num_class;
   if (objective_type_multiclass) {
@@ -201,7 +201,7 @@ void OverallConfig::CheckParamConflict() {
   }
   if (boosting_config.is_provide_training_metric || !io_config.valid_data_filenames.empty()) {
     for (std::string metric_type : metric_types) {
       bool metric_type_multiclass = (metric_type == std::string("multi_logloss")
                                      || metric_type == std::string("multi_error"));
       if ((objective_type_multiclass && !metric_type_multiclass)
           || (!objective_type_multiclass && metric_type_multiclass)) {
@@ -209,7 +209,7 @@ void OverallConfig::CheckParamConflict() {
       }
     }
   }
   if (network_config.num_machines > 1) {
     is_parallel = true;
   } else {
@@ -229,7 +229,7 @@ void OverallConfig::CheckParamConflict() {
   } else if (boosting_config.tree_learner_type == std::string("data")
              || boosting_config.tree_learner_type == std::string("voting")) {
     is_parallel_find_bin = true;
     if (boosting_config.tree_config.histogram_pool_size >= 0
         && boosting_config.tree_learner_type == std::string("data")) {
       Log::Warning("Histogram LRU queue was enabled (histogram_pool_size=%f). Will disable this to reduce communication costs",
                    boosting_config.tree_config.histogram_pool_size);
@@ -257,6 +257,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
   GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
   GetBool(params, "is_predict_raw_score", &is_predict_raw_score);
   GetBool(params, "is_predict_leaf_index", &is_predict_leaf_index);
+  GetBool(params, "is_predict_contrib", &is_predict_contrib);
   GetInt(params, "snapshot_freq", &snapshot_freq);
   GetString(params, "output_model", &output_model);
   GetString(params, "input_model", &input_model);
...