"include/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "1cacaef96158266189c7f394368117e88a8ab395"
Commit 67c2bdf9 authored by Scott Lundberg's avatar Scott Lundberg Committed by Guolin Ke
Browse files

Fix feature attributions for regression models and add Python bindings (#861)

* Fix feature attributions for regression models and add Python bindings

* Address pylint issue

* Lazy fix missing tree depth info
parent 5543979b
...@@ -111,7 +111,7 @@ public: ...@@ -111,7 +111,7 @@ public:
inline int PredictLeafIndex(const double* feature_values) const; inline int PredictLeafIndex(const double* feature_values) const;
inline void PredictContrib(const double* feature_values, int num_features, double* output) const; inline void PredictContrib(const double* feature_values, int num_features, double* output);
/*! \brief Get Number of leaves*/ /*! \brief Get Number of leaves*/
inline int num_leaves() const { return num_leaves_; } inline int num_leaves() const { return num_leaves_; }
...@@ -299,9 +299,12 @@ private: ...@@ -299,9 +299,12 @@ private:
/*! \brief Serialize one node to if-else statement*/ /*! \brief Serialize one node to if-else statement*/
std::string NodeToIfElse(int index, bool is_predict_leaf_index) const; std::string NodeToIfElse(int index, bool is_predict_leaf_index) const;
double ExpectedValue(int node) const; double ExpectedValue() const;
int MaxDepth() const; int MaxDepth();
/*! \brief This is used fill in leaf_depth_ after reloading a model*/
inline void RecomputeLeafDepths(int node = 0, int depth = 0);
/*! /*!
* \brief Used by TreeSHAP for data we keep about our decision path * \brief Used by TreeSHAP for data we keep about our decision path
...@@ -431,12 +434,25 @@ inline int Tree::PredictLeafIndex(const double* feature_values) const { ...@@ -431,12 +434,25 @@ inline int Tree::PredictLeafIndex(const double* feature_values) const {
} }
} }
inline void Tree::PredictContrib(const double* feature_values, int num_features, double* output) const { inline void Tree::PredictContrib(const double* feature_values, int num_features, double* output) {
output[num_features] += ExpectedValue(0); output[num_features] += ExpectedValue();
// Run the recursion with preallocated space for the unique path data // Run the recursion with preallocated space for the unique path data
const int max_path_len = MaxDepth() + 1; if (num_leaves_ > 1) {
std::vector<PathElement> unique_path_data((max_path_len*(max_path_len + 1)) / 2); const int max_path_len = MaxDepth()+1;
TreeSHAP(feature_values, output, 0, 0, unique_path_data.data(), 1, 1, -1); PathElement *unique_path_data = new PathElement[(max_path_len*(max_path_len+1))/2];
TreeSHAP(feature_values, output, 0, 0, unique_path_data, 1, 1, -1);
delete[] unique_path_data;
}
}
inline void Tree::RecomputeLeafDepths(int node, int depth) {
if (node == 0) leaf_depth_.resize(num_leaves());
if (node < 0) {
leaf_depth_[~node] = depth;
} else {
RecomputeLeafDepths(left_child_[node], depth+1);
RecomputeLeafDepths(right_child_[node], depth+1);
}
} }
inline int Tree::GetLeaf(const double* feature_values) const { inline int Tree::GetLeaf(const double* feature_values) const {
......
...@@ -170,6 +170,7 @@ C_API_IS_ROW_MAJOR = 1 ...@@ -170,6 +170,7 @@ C_API_IS_ROW_MAJOR = 1
C_API_PREDICT_NORMAL = 0 C_API_PREDICT_NORMAL = 0
C_API_PREDICT_RAW_SCORE = 1 C_API_PREDICT_RAW_SCORE = 1
C_API_PREDICT_LEAF_INDEX = 2 C_API_PREDICT_LEAF_INDEX = 2
C_API_PREDICT_CONTRIB = 3
"""data type of data field""" """data type of data field"""
FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32, FIELD_TYPE_MAPPER = {"label": C_API_DTYPE_FLOAT32,
...@@ -351,7 +352,7 @@ class _InnerPredictor(object): ...@@ -351,7 +352,7 @@ class _InnerPredictor(object):
return this return this
def predict(self, data, num_iteration=-1, def predict(self, data, num_iteration=-1,
raw_score=False, pred_leaf=False, data_has_header=False, raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False,
is_reshape=True): is_reshape=True):
""" """
Predict logic Predict logic
...@@ -367,6 +368,8 @@ class _InnerPredictor(object): ...@@ -367,6 +368,8 @@ class _InnerPredictor(object):
True for predict raw score True for predict raw score
pred_leaf : bool pred_leaf : bool
True for predict leaf index True for predict leaf index
pred_contrib : bool
True for predict feature contributions
data_has_header : bool data_has_header : bool
Used for txt data, True if txt data has header Used for txt data, True if txt data has header
is_reshape : bool is_reshape : bool
...@@ -384,6 +387,8 @@ class _InnerPredictor(object): ...@@ -384,6 +387,8 @@ class _InnerPredictor(object):
predict_type = C_API_PREDICT_RAW_SCORE predict_type = C_API_PREDICT_RAW_SCORE
if pred_leaf: if pred_leaf:
predict_type = C_API_PREDICT_LEAF_INDEX predict_type = C_API_PREDICT_LEAF_INDEX
if pred_contrib:
predict_type = C_API_PREDICT_CONTRIB
int_data_has_header = 1 if data_has_header else 0 int_data_has_header = 1 if data_has_header else 0
if num_iteration > self.num_total_iteration: if num_iteration > self.num_total_iteration:
num_iteration = self.num_total_iteration num_iteration = self.num_total_iteration
...@@ -1653,7 +1658,7 @@ class Booster(object): ...@@ -1653,7 +1658,7 @@ class Booster(object):
ptr_string_buffer)) ptr_string_buffer))
return json.loads(string_buffer.value.decode()) return json.loads(string_buffer.value.decode())
def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, pred_contrib=False,
data_has_header=False, is_reshape=True, pred_parameter=None): data_has_header=False, is_reshape=True, pred_parameter=None):
"""Make a prediction. """Make a prediction.
...@@ -1669,6 +1674,8 @@ class Booster(object): ...@@ -1669,6 +1674,8 @@ class Booster(object):
Whether to predict raw scores. Whether to predict raw scores.
pred_leaf : bool, optional (default=False) pred_leaf : bool, optional (default=False)
Whether to predict leaf index. Whether to predict leaf index.
pred_contrib : bool, optional (default=False)
Whether to predict feature contributions.
data_has_header : bool, optional (default=False) data_has_header : bool, optional (default=False)
Whether the data has header. Whether the data has header.
Used only if data is string. Used only if data is string.
...@@ -1685,7 +1692,7 @@ class Booster(object): ...@@ -1685,7 +1692,7 @@ class Booster(object):
predictor = self._to_predictor(pred_parameter) predictor = self._to_predictor(pred_parameter)
if num_iteration <= 0: if num_iteration <= 0:
num_iteration = self.best_iteration num_iteration = self.best_iteration
return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape) return predictor.predict(data, num_iteration, raw_score, pred_leaf, pred_contrib, data_has_header, is_reshape)
def get_leaf_output(self, tree_id, leaf_id): def get_leaf_output(self, tree_id, leaf_id):
"""Get the output of a leaf. """Get the output of a leaf.
......
...@@ -589,7 +589,6 @@ void Tree::TreeSHAP(const double *feature_values, double *phi, ...@@ -589,7 +589,6 @@ void Tree::TreeSHAP(const double *feature_values, double *phi,
if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path); if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
ExtendPath(unique_path, unique_depth, parent_zero_fraction, ExtendPath(unique_path, unique_depth, parent_zero_fraction,
parent_one_fraction, parent_feature_index); parent_one_fraction, parent_feature_index);
const int split_index = split_feature_[node];
// leaf node // leaf node
if (node < 0) { if (node < 0) {
...@@ -601,7 +600,7 @@ void Tree::TreeSHAP(const double *feature_values, double *phi, ...@@ -601,7 +600,7 @@ void Tree::TreeSHAP(const double *feature_values, double *phi,
// internal node // internal node
} else { } else {
const int hot_index = Decision(feature_values[split_index], node); const int hot_index = Decision(feature_values[split_feature_[node]], node);
const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]); const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
const double w = data_count(node); const double w = data_count(node);
const double hot_zero_fraction = data_count(hot_index) / w; const double hot_zero_fraction = data_count(hot_index) / w;
...@@ -613,7 +612,7 @@ void Tree::TreeSHAP(const double *feature_values, double *phi, ...@@ -613,7 +612,7 @@ void Tree::TreeSHAP(const double *feature_values, double *phi,
// if so we undo that split so we can redo it for this node // if so we undo that split so we can redo it for this node
int path_index = 0; int path_index = 0;
for (; path_index <= unique_depth; ++path_index) { for (; path_index <= unique_depth; ++path_index) {
if (unique_path[path_index].feature_index == split_index) break; if (unique_path[path_index].feature_index == split_feature_[node]) break;
} }
if (path_index != unique_depth + 1) { if (path_index != unique_depth + 1) {
incoming_zero_fraction = unique_path[path_index].zero_fraction; incoming_zero_fraction = unique_path[path_index].zero_fraction;
...@@ -623,25 +622,26 @@ void Tree::TreeSHAP(const double *feature_values, double *phi, ...@@ -623,25 +622,26 @@ void Tree::TreeSHAP(const double *feature_values, double *phi,
} }
TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path, TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_index); hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path, TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
cold_zero_fraction*incoming_zero_fraction, 0, split_index); cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
} }
} }
double Tree::ExpectedValue() const {
double Tree::ExpectedValue(int node) const { if (num_leaves_ == 1) return LeafOutput(0);
if (node >= 0) { const double total_count = internal_count_[0];
const int l = left_child_[node]; double exp_value = 0.0;
const int r = right_child_[node]; for (int i = 0; i < num_leaves(); ++i) {
return (data_count(l)*ExpectedValue(l) + data_count(r)*ExpectedValue(r)) / data_count(node); exp_value += (leaf_count_[i]/total_count)*LeafOutput(i);
} else {
return LeafOutput(~node);
} }
return exp_value;
} }
int Tree::MaxDepth() const { int Tree::MaxDepth() {
if (leaf_depth_.size() == 0) RecomputeLeafDepths();
if (num_leaves_ == 1) return 0;
int max_depth = 0; int max_depth = 0;
for (int i = 0; i < num_leaves(); ++i) { for (int i = 0; i < num_leaves(); ++i) {
if (max_depth < leaf_depth_[i]) max_depth = leaf_depth_[i]; if (max_depth < leaf_depth_[i]) max_depth = leaf_depth_[i];
......
...@@ -462,3 +462,23 @@ class TestEngine(unittest.TestCase): ...@@ -462,3 +462,23 @@ class TestEngine(unittest.TestCase):
tmp_dat_val = tmp_dat.subset(np.arange(80, 100)).subset(np.arange(18)) tmp_dat_val = tmp_dat.subset(np.arange(80, 100)).subset(np.arange(18))
params = {'objective': 'regression_l2', 'metric': 'rmse'} params = {'objective': 'regression_l2', 'metric': 'rmse'}
gbm = lgb.train(params, tmp_dat_train, num_boost_round=20, valid_sets=[tmp_dat_train, tmp_dat_val]) gbm = lgb.train(params, tmp_dat_train, num_boost_round=20, valid_sets=[tmp_dat_train, tmp_dat_val])
def test_contribs(self):
X, y = load_breast_cancer(True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
params = {
'objective': 'binary',
'metric': 'binary_logloss',
'verbose': -1,
'num_iteration': 50 # test num_iteration in dict here
}
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
evals_result = {}
gbm = lgb.train(params, lgb_train,
num_boost_round=20,
valid_sets=lgb_eval,
verbose_eval=False,
evals_result=evals_result)
self.assertLess(np.linalg.norm(gbm.predict(X_test, raw_score=True) - np.sum(gbm.predict(X_test, pred_contrib=True), axis=1)), 1e-4)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment