tree.cpp 40.4 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
#include <LightGBM/tree.h>

#include <LightGBM/dataset.h>
8
9
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
10

11
12
13
14
#include <functional>
#include <iomanip>
#include <sstream>

Guolin Ke's avatar
Guolin Ke committed
15
16
namespace LightGBM {

17
/*!
 * \brief Build an empty tree that consists of a single root leaf.
 * \param max_leaves Number of slots allocated for every per-leaf array; per-split arrays get max_leaves - 1 slots
 * \param track_branch_features When true, one (initially empty) feature list per leaf is allocated in branch_features_
 * \param is_linear When true, the per-leaf linear-model containers are allocated as well
 */
Tree::Tree(int max_leaves, bool track_branch_features, bool is_linear)
  :max_leaves_(max_leaves), track_branch_features_(track_branch_features) {
  const int max_splits = max_leaves_ - 1;

  // one slot per internal (split) node
  left_child_.resize(max_splits);
  right_child_.resize(max_splits);
  split_feature_inner_.resize(max_splits);
  split_feature_.resize(max_splits);
  threshold_in_bin_.resize(max_splits);
  threshold_.resize(max_splits);
  decision_type_.resize(max_splits, 0);
  split_gain_.resize(max_splits);
  internal_value_.resize(max_splits);
  internal_weight_.resize(max_splits);
  internal_count_.resize(max_splits);

  // one slot per leaf
  leaf_parent_.resize(max_leaves_);
  leaf_value_.resize(max_leaves_);
  leaf_weight_.resize(max_leaves_);
  leaf_count_.resize(max_leaves_);
  leaf_depth_.resize(max_leaves_);
  if (track_branch_features_) {
    branch_features_.assign(max_leaves_, std::vector<int>());
  }

  // the tree starts as one leaf (index 0) at depth 0 with no parent
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_parent_[0] = -1;
  leaf_value_[0] = 0.0f;
  leaf_weight_[0] = 0.0f;

  shrinkage_ = 1.0f;
  max_depth_ = -1;

  // no categorical split yet; both boundary arrays start with the leading 0
  num_cat_ = 0;
  cat_boundaries_.push_back(0);
  cat_boundaries_inner_.push_back(0);

  is_linear_ = is_linear;
  if (is_linear_) {
    leaf_coeff_.resize(max_leaves_);
    leaf_const_.assign(max_leaves_, 0);
    leaf_features_.resize(max_leaves_);
    leaf_features_inner_.resize(max_leaves_);
  }
}
Guolin Ke's avatar
Guolin Ke committed
57

58
59
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
                double threshold_double, double left_value, double right_value,
60
61
                int left_cnt, int right_cnt, double left_weight, double right_weight, float gain,
                MissingType missing_type, bool default_left) {
62
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
Guolin Ke's avatar
Guolin Ke committed
63
  int new_node_idx = num_leaves_ - 1;
Guolin Ke's avatar
Guolin Ke committed
64
  decision_type_[new_node_idx] = 0;
65
  SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
66
  SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
Guolin Ke's avatar
Guolin Ke committed
67
  SetMissingType(&decision_type_[new_node_idx], static_cast<int8_t>(missing_type));
Guolin Ke's avatar
Guolin Ke committed
68
  threshold_in_bin_[new_node_idx] = threshold_bin;
Guolin Ke's avatar
Guolin Ke committed
69
  threshold_[new_node_idx] = threshold_double;
70
71
72
  ++num_leaves_;
  return num_leaves_ - 1;
}
Guolin Ke's avatar
Guolin Ke committed
73

74
75
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                           const uint32_t* threshold, int num_threshold, double left_value, double right_value,
76
77
                           data_size_t left_cnt, data_size_t right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type) {
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
78
79
80
  int new_node_idx = num_leaves_ - 1;
  decision_type_[new_node_idx] = 0;
  SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
81
  SetMissingType(&decision_type_[new_node_idx], static_cast<int8_t>(missing_type));
82
83
  threshold_in_bin_[new_node_idx] = num_cat_;
  threshold_[new_node_idx] = num_cat_;
84
  ++num_cat_;
85
86
87
88
89
90
91
92
  cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
  for (int i = 0; i < num_threshold; ++i) {
    cat_threshold_.push_back(threshold[i]);
  }
  cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
  for (int i = 0; i < num_threshold_bin; ++i) {
    cat_threshold_inner_.push_back(threshold_bin[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
93
94
95
96
  ++num_leaves_;
  return num_leaves_ - 1;
}

Guolin Ke's avatar
Guolin Ke committed
97
98
99
100
101
102
103
104
105
106
107
108
109
110
/*
 * PredictionFun: statement sequence meant to be expanded inside a scoring
 * lambda. It adds the constant leaf output of this tree to `score` for the
 * rows in [start, end).
 *
 * Names that must be visible at the expansion site: data, start, end, score,
 * default_bins, max_bins (plus the Tree members used below).
 *
 *   niter         number of BinIterator objects to create
 *   fidx_in_iter  feature index handed to data->FeatureIterator for iterator
 *                 `i`; the expression may refer to the creation-loop index `i`
 *   start_pos     position every iterator is Reset() to
 *   decision_fun  callable returning the next node; a negative result is the
 *                 bitwise complement of the reached leaf index
 *   iter_idx      which iterator to read for the current node; the expression
 *                 may refer to `node`
 *   data_idx      row index used for both the iterator read and the score
 *                 write; the expression may refer to the row-loop index `i`
 *
 * The macro declares `iter`, `i` and `node`, so the arguments are evaluated in
 * a scope where those names exist. The last line deliberately carries no
 * trailing backslash so that the line following the macro is never spliced
 * into it.
 */
#define PredictionFun(niter, fidx_in_iter, start_pos, decision_fun, iter_idx, \
                      data_idx)                                               \
  std::vector<std::unique_ptr<BinIterator>> iter((niter));                    \
  for (int i = 0; i < (niter); ++i) {                                         \
    iter[i].reset(data->FeatureIterator((fidx_in_iter)));                     \
    iter[i]->Reset((start_pos));                                              \
  }                                                                           \
  for (data_size_t i = start; i < end; ++i) {                                 \
    int node = 0;                                                             \
    while (node >= 0) {                                                       \
      node = decision_fun(iter[(iter_idx)]->Get((data_idx)), node,            \
                          default_bins[node], max_bins[node]);                \
    }                                                                         \
    score[(data_idx)] += static_cast<double>(leaf_value_[~node]);             \
  }

113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146

/*
 * PredictionFunLinear: statement sequence meant to be expanded inside a
 * scoring lambda. For the rows in [start, end) it locates the leaf and adds
 * that leaf's linear-model output to `score`.
 *
 * Names that must be visible at the expansion site: data, start, end, score,
 * default_bins, max_bins, feat_ptr (plus the Tree members used below).
 * The arguments have the same meaning as for PredictionFun: `fidx_in_iter`
 * and `data_idx` may refer to the loop index `i`, `iter_idx` may refer to
 * `node`.
 *
 * Per row the output is leaf_const_[leaf] + sum_j leaf_coeff_[leaf][j] * x_j,
 * where x_j is read through feat_ptr[leaf][j]. If any x_j is NaN the plain
 * leaf_value_[leaf] is added instead.
 *
 * A tree with a single leaf has no split nodes, so `iter`, `default_bins`
 * and `max_bins` can be empty; the traversal is therefore only performed when
 * num_leaves_ > 1 and leaf 0 is used otherwise. The last line deliberately
 * carries no trailing backslash so that the line following the macro is never
 * spliced into it.
 */
#define PredictionFunLinear(niter, fidx_in_iter, start_pos, decision_fun,     \
                            iter_idx, data_idx)                               \
  std::vector<std::unique_ptr<BinIterator>> iter((niter));                    \
  for (int i = 0; i < (niter); ++i) {                                         \
    iter[i].reset(data->FeatureIterator((fidx_in_iter)));                     \
    iter[i]->Reset((start_pos));                                              \
  }                                                                           \
  for (data_size_t i = start; i < end; ++i) {                                 \
    int leaf = 0;                                                             \
    if (num_leaves_ > 1) {                                                    \
      int node = 0;                                                           \
      while (node >= 0) {                                                     \
        node = decision_fun(iter[(iter_idx)]->Get((data_idx)), node,          \
                            default_bins[node], max_bins[node]);              \
      }                                                                       \
      leaf = ~node;                                                           \
    }                                                                         \
    double add_score = leaf_const_[leaf];                                     \
    bool nan_found = false;                                                   \
    const double* coeff_ptr = leaf_coeff_[leaf].data();                       \
    const float** data_ptr = feat_ptr[leaf].data();                           \
    for (size_t j = 0; j < leaf_features_inner_[leaf].size(); ++j) {          \
      float feat_val = data_ptr[j][(data_idx)];                               \
      if (std::isnan(feat_val)) {                                             \
        nan_found = true;                                                     \
        break;                                                                \
      }                                                                       \
      add_score += coeff_ptr[j] * feat_val;                                   \
    }                                                                         \
    if (nan_found) {                                                          \
      score[(data_idx)] += leaf_value_[leaf];                                 \
    } else {                                                                  \
      score[(data_idx)] += add_score;                                         \
    }                                                                         \
  }


147
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
148
  if (!is_linear_ && num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
149
    if (leaf_value_[0] != 0.0f) {
Guolin Ke's avatar
Guolin Ke committed
150
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
Guolin Ke's avatar
Guolin Ke committed
151
152
153
154
155
156
      for (data_size_t i = 0; i < num_data; ++i) {
        score[i] += leaf_value_[0];
      }
    }
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
157
158
159
160
161
162
163
164
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
  if (is_linear_) {
    std::vector<std::vector<const float*>> feat_ptr(num_leaves_);
    for (int leaf_num = 0; leaf_num < num_leaves_; ++leaf_num) {
      for (int feat : leaf_features_inner_[leaf_num]) {
        feat_ptr[leaf_num].push_back(data->raw_index(feat));
      }
    }
    if (num_cat_ > 0) {
      if (data->num_features() > num_leaves_ - 1) {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins, &feat_ptr]
        (int, data_size_t start, data_size_t end) {
          PredictionFunLinear(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
        });
      } else {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins, &feat_ptr]
        (int, data_size_t start, data_size_t end) {
          PredictionFunLinear(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
        });
      }
184
    } else {
185
186
187
188
189
190
191
192
193
194
195
      if (data->num_features() > num_leaves_ - 1) {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins, &feat_ptr]
        (int, data_size_t start, data_size_t end) {
          PredictionFunLinear(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
        });
      } else {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins, &feat_ptr]
        (int, data_size_t start, data_size_t end) {
          PredictionFunLinear(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
        });
      }
196
    }
Guolin Ke's avatar
Guolin Ke committed
197
  } else {
198
199
200
201
202
203
204
205
206
207
208
209
    if (num_cat_ > 0) {
      if (data->num_features() > num_leaves_ - 1) {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
        (int, data_size_t start, data_size_t end) {
          PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
        });
      } else {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
        (int, data_size_t start, data_size_t end) {
          PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
        });
      }
210
    } else {
211
212
213
214
215
216
217
218
219
220
221
      if (data->num_features() > num_leaves_ - 1) {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
        (int, data_size_t start, data_size_t end) {
          PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
        });
      } else {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
        (int, data_size_t start, data_size_t end) {
          PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
        });
      }
222
    }
Guolin Ke's avatar
Guolin Ke committed
223
  }
Guolin Ke's avatar
Guolin Ke committed
224
225
}

Guolin Ke's avatar
Guolin Ke committed
226
void Tree::AddPredictionToScore(const Dataset* data,
227
228
229
  const data_size_t* used_data_indices,
  data_size_t num_data, double* score) const {
  if (!is_linear_ && num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
230
    if (leaf_value_[0] != 0.0f) {
Guolin Ke's avatar
Guolin Ke committed
231
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
Guolin Ke's avatar
Guolin Ke committed
232
233
234
235
      for (data_size_t i = 0; i < num_data; ++i) {
        score[used_data_indices[i]] += leaf_value_[0];
      }
    }
Guolin Ke's avatar
Guolin Ke committed
236
    return;
Guolin Ke's avatar
Guolin Ke committed
237
  }
Guolin Ke's avatar
Guolin Ke committed
238
239
240
241
242
243
244
245
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
  if (is_linear_) {
    std::vector<std::vector<const float*>> feat_ptr(num_leaves_);
    for (int leaf_num = 0; leaf_num < num_leaves_; ++leaf_num) {
      for (int feat : leaf_features_inner_[leaf_num]) {
        feat_ptr[leaf_num].push_back(data->raw_index(feat));
      }
    }
    if (num_cat_ > 0) {
      if (data->num_features() > num_leaves_ - 1) {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins, &feat_ptr]
        (int, data_size_t start, data_size_t end) {
          PredictionFunLinear(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner,
                              node, used_data_indices[i]);
        });
      } else {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins, &feat_ptr]
        (int, data_size_t start, data_size_t end) {
          PredictionFunLinear(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
        });
      }
266
    } else {
267
268
269
270
271
272
273
274
275
276
277
278
279
      if (data->num_features() > num_leaves_ - 1) {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins, &feat_ptr]
        (int, data_size_t start, data_size_t end) {
          PredictionFunLinear(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner,
                              node, used_data_indices[i]);
        });
      } else {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins, &feat_ptr]
        (int, data_size_t start, data_size_t end) {
          PredictionFunLinear(data->num_features(), i, used_data_indices[start], NumericalDecisionInner,
                              split_feature_inner_[node], used_data_indices[i]);
        });
      }
280
    }
Guolin Ke's avatar
Guolin Ke committed
281
  } else {
282
283
284
285
286
287
288
289
290
291
292
293
    if (num_cat_ > 0) {
      if (data->num_features() > num_leaves_ - 1) {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
        (int, data_size_t start, data_size_t end) {
          PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
        });
      } else {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
        (int, data_size_t start, data_size_t end) {
          PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
        });
      }
294
    } else {
295
296
297
298
299
300
301
302
303
304
305
      if (data->num_features() > num_leaves_ - 1) {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
        (int, data_size_t start, data_size_t end) {
          PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
        });
      } else {
        Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
        (int, data_size_t start, data_size_t end) {
          PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
        });
      }
306
    }
Guolin Ke's avatar
Guolin Ke committed
307
  }
Guolin Ke's avatar
Guolin Ke committed
308
309
}

310
#undef PredictionFun
311
#undef PredictionFunLinear
312

313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
/*! \brief Largest value among the first num_leaves_ entries of leaf_value_. */
double Tree::GetUpperBoundValue() const {
  double largest = leaf_value_[0];
  for (int leaf = 1; leaf < num_leaves_; ++leaf) {
    const double candidate = leaf_value_[leaf];
    largest = (candidate > largest) ? candidate : largest;
  }
  return largest;
}

/*! \brief Smallest value among the first num_leaves_ entries of leaf_value_. */
double Tree::GetLowerBoundValue() const {
  double smallest = leaf_value_[0];
  for (int leaf = 1; leaf < num_leaves_; ++leaf) {
    const double candidate = leaf_value_[leaf];
    smallest = (candidate < smallest) ? candidate : smallest;
  }
  return smallest;
}

Guolin Ke's avatar
Guolin Ke committed
333
std::string Tree::ToString() const {
334
  std::stringstream str_buf;
335
336
337
338
339
340
341
342
  Common::C_stringstream(str_buf);

  #if ((defined(sun) || defined(__sun)) && (defined(__SVR4) || defined(__svr4__)))
  using CommonLegacy::ArrayToString;  // Slower & unsafe regarding locale.
  #else
  using CommonC::ArrayToString;
  #endif

343
344
  str_buf << "num_leaves=" << num_leaves_ << '\n';
  str_buf << "num_cat=" << num_cat_ << '\n';
345
  str_buf << "split_feature="
346
    << ArrayToString(split_feature_, num_leaves_ - 1) << '\n';
347
  str_buf << "split_gain="
348
    << ArrayToString(split_gain_, num_leaves_ - 1) << '\n';
349
  str_buf << "threshold="
350
    << ArrayToString<true>(threshold_, num_leaves_ - 1) << '\n';
351
  str_buf << "decision_type="
352
    << ArrayToString(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1) << '\n';
353
  str_buf << "left_child="
354
    << ArrayToString(left_child_, num_leaves_ - 1) << '\n';
355
  str_buf << "right_child="
356
    << ArrayToString(right_child_, num_leaves_ - 1) << '\n';
357
  str_buf << "leaf_value="
358
    << ArrayToString<true>(leaf_value_, num_leaves_) << '\n';
359
  str_buf << "leaf_weight="
360
    << ArrayToString<true>(leaf_weight_, num_leaves_) << '\n';
361
  str_buf << "leaf_count="
362
    << ArrayToString(leaf_count_, num_leaves_) << '\n';
363
  str_buf << "internal_value="
364
    << ArrayToString(internal_value_, num_leaves_ - 1) << '\n';
365
  str_buf << "internal_weight="
366
    << ArrayToString(internal_weight_, num_leaves_ - 1) << '\n';
367
  str_buf << "internal_count="
368
    << ArrayToString(internal_count_, num_leaves_ - 1) << '\n';
369
370
  if (num_cat_ > 0) {
    str_buf << "cat_boundaries="
371
      << ArrayToString(cat_boundaries_, num_cat_ + 1) << '\n';
372
    str_buf << "cat_threshold="
373
      << ArrayToString(cat_threshold_, cat_threshold_.size()) << '\n';
374
  }
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
  str_buf << "is_linear=" << is_linear_ << '\n';

  if (is_linear_) {
    str_buf << "leaf_const="
      << ArrayToString(leaf_const_, num_leaves_) << '\n';
    std::vector<int> num_feat(num_leaves_);
    for (int i = 0; i < num_leaves_; ++i) {
      num_feat[i] = leaf_coeff_[i].size();
    }
    str_buf << "num_features="
      << ArrayToString(num_feat, num_leaves_) << '\n';
    str_buf << "leaf_features=";
    for (int i = 0; i < num_leaves_; ++i) {
      if (num_feat[i] > 0) {
        str_buf << ArrayToString(leaf_features_[i], leaf_features_[i].size()) << ' ';
      }
      str_buf << ' ';
    }
    str_buf << '\n';
    str_buf << "leaf_coeff=";
    for (int i = 0; i < num_leaves_; ++i) {
      if (num_feat[i] > 0) {
        str_buf << ArrayToString(leaf_coeff_[i], leaf_coeff_[i].size()) << ' ';
      }
      str_buf << ' ';
    }
    str_buf << '\n';
  }
403
404
  str_buf << "shrinkage=" << shrinkage_ << '\n';
  str_buf << '\n';
405

406
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
407
408
}

Guolin Ke's avatar
Guolin Ke committed
409
std::string Tree::ToJSON() const {
410
  std::stringstream str_buf;
411
  Common::C_stringstream(str_buf);
412
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
413
414
415
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
  str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
wxchan's avatar
wxchan committed
416
  if (num_leaves_ == 1) {
417
    str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
wxchan's avatar
wxchan committed
418
  } else {
419
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
wxchan's avatar
wxchan committed
420
  }
wxchan's avatar
wxchan committed
421

422
  return str_buf.str();
wxchan's avatar
wxchan committed
423
424
}

Guolin Ke's avatar
Guolin Ke committed
425
std::string Tree::NodeToJSON(int index) const {
426
  std::stringstream str_buf;
427
  Common::C_stringstream(str_buf);
428
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
429
430
  if (index >= 0) {
    // non-leaf
431
432
433
    str_buf << "{" << '\n';
    str_buf << "\"split_index\":" << index << "," << '\n';
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
Guolin Ke's avatar
Guolin Ke committed
434
    str_buf << "\"split_gain\":" << Common::AvoidInf(split_gain_[index]) << "," << '\n';
435
    if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
436
437
438
439
440
441
442
443
444
445
446
      int cat_idx = static_cast<int>(threshold_[index]);
      std::vector<int> cats;
      for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
        for (int j = 0; j < 32; ++j) {
          int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
          if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
            cats.push_back(cat);
          }
        }
      }
447
      str_buf << "\"threshold\":\"" << CommonC::Join(cats, "||") << "\"," << '\n';
448
      str_buf << "\"decision_type\":\"==\"," << '\n';
449
    } else {
450
451
      str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << '\n';
      str_buf << "\"decision_type\":\"<=\"," << '\n';
452
453
    }
    if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
454
      str_buf << "\"default_left\":true," << '\n';
455
    } else {
456
      str_buf << "\"default_left\":false," << '\n';
457
458
    }
    uint8_t missing_type = GetMissingType(decision_type_[index]);
459
    if (missing_type == MissingType::None) {
460
      str_buf << "\"missing_type\":\"None\"," << '\n';
461
    } else if (missing_type == MissingType::Zero) {
462
      str_buf << "\"missing_type\":\"Zero\"," << '\n';
463
    } else {
464
      str_buf << "\"missing_type\":\"NaN\"," << '\n';
465
    }
466
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
467
    str_buf << "\"internal_weight\":" << internal_weight_[index] << "," << '\n';
468
469
470
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
471
    str_buf << "}";
wxchan's avatar
wxchan committed
472
473
474
  } else {
    // leaf
    index = ~index;
475
476
477
    str_buf << "{" << '\n';
    str_buf << "\"leaf_index\":" << index << "," << '\n';
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
478
    str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n';
479
    str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
480
    str_buf << "}";
wxchan's avatar
wxchan committed
481
482
  }

483
  return str_buf.str();
wxchan's avatar
wxchan committed
484
485
}

Guolin Ke's avatar
Guolin Ke committed
486
487
std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
488
  Common::C_stringstream(str_buf);
Guolin Ke's avatar
Guolin Ke committed
489
490
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
Nikita Titov's avatar
Nikita Titov committed
491
  if (missing_type == MissingType::None
492
      || (missing_type == MissingType::Zero && default_left && kZeroThreshold < threshold_[node])) {
Guolin Ke's avatar
Guolin Ke committed
493
    str_buf << "if (fval <= " << threshold_[node] << ") {";
494
  } else if (missing_type == MissingType::Zero) {
Guolin Ke's avatar
Guolin Ke committed
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
513
  Common::C_stringstream(str_buf);
514
  if (missing_type == MissingType::NaN) {
Guolin Ke's avatar
Guolin Ke committed
515
516
517
518
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
519
  int cat_idx = static_cast<int>(threshold_[node]);
520
521
522
523
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
Guolin Ke's avatar
Guolin Ke committed
524
525
526
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
527
std::string Tree::ToIfElse(int index, bool predict_leaf_index) const {
528
  std::stringstream str_buf;
529
  Common::C_stringstream(str_buf);
530
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
531
  if (predict_leaf_index) {
532
533
534
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
535
536
  if (num_leaves_ <= 1) {
    str_buf << "return " << leaf_value_[0] << ";";
537
  } else {
538
539
540
541
542
543
544
545
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
546
547
548
549
550
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
551
    str_buf << NodeToIfElse(0, predict_leaf_index);
552
  }
553
  str_buf << " }" << '\n';
554

555
  // Predict func by Map to ifelse
556
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
557
  if (predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
558
559
560
    str_buf << "LeafByMap";
  } else {
    str_buf << "ByMap";
561
562
563
  }
  str_buf << "(const std::unordered_map<int, double>& arr) { ";
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
    str_buf << "return " << leaf_value_[0] << ";";
  } else {
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
579
    str_buf << NodeToIfElseByMap(0, predict_leaf_index);
580
  }
581
  str_buf << " }" << '\n';
582

583
584
585
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
586
std::string Tree::NodeToIfElse(int index, bool predict_leaf_index) const {
587
  std::stringstream str_buf;
588
  Common::C_stringstream(str_buf);
589
590
591
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
592
    str_buf << "fval = arr[" << split_feature_[index] << "];";
Guolin Ke's avatar
Guolin Ke committed
593
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
594
      str_buf << NumericalDecisionIfElse(index);
595
    } else {
596
      str_buf << CategoricalDecisionIfElse(index);
597
598
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
599
    str_buf << NodeToIfElse(left_child_[index], predict_leaf_index);
600
601
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
602
    str_buf << NodeToIfElse(right_child_[index], predict_leaf_index);
603
604
605
606
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
607
    if (predict_leaf_index) {
608
609
610
611
612
613
614
615
616
617
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
618
std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
619
  std::stringstream str_buf;
620
  Common::C_stringstream(str_buf);
621
622
623
624
625
626
627
628
629
630
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
    str_buf << "fval = arr.count(" << split_feature_[index] << ") > 0 ? arr.at(" << split_feature_[index] << ") : 0.0f;";
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
      str_buf << NumericalDecisionIfElse(index);
    } else {
      str_buf << CategoricalDecisionIfElse(index);
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
631
    str_buf << NodeToIfElseByMap(left_child_[index], predict_leaf_index);
632
633
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
634
    str_buf << NodeToIfElseByMap(right_child_[index], predict_leaf_index);
635
636
637
638
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
639
    if (predict_leaf_index) {
640
641
642
643
644
645
646
647
648
649
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

650
651
// Builds a tree from its textual model representation: a run of "key=value"
// lines terminated by a blank line (or by the end of the string).
//
//   str      : start of the tree section; must be NUL-terminated.
//   used_len : receives the number of characters consumed from `str`.
//
// Required keys: num_leaves, num_cat, leaf_value. For trees that are not a
// single non-linear leaf, left_child, right_child, split_feature and threshold
// are also required; the remaining per-node / per-leaf arrays are optional and
// are sized (value-initialized) when absent. Missing required keys are
// reported through Log::Fatal.
Tree::Tree(const char* str, size_t* used_len) {
  auto p = str;
  std::unordered_map<std::string, std::string> key_vals;
  // upper bound on the number of key=value lines consumed for one tree
  const int max_num_line = 22;
  int read_line = 0;
  while (read_line < max_num_line) {
    // a blank line or the end of the buffer terminates the tree section
    if (*p == '\r' || *p == '\n' || *p == '\0') break;
    auto start = p;
    // never scan past the terminating NUL while looking for the separator
    while (*p != '=' && *p != '\0') ++p;
    if (*p == '\0') {
      Log::Fatal("Tree model string format error, reached end of string while looking for '='");
    }
    std::string key(start, p - start);
    ++p;
    start = p;
    // the value runs to the end of the line; a final line without a newline
    // ends at the NUL instead of running off the buffer
    while (*p != '\r' && *p != '\n' && *p != '\0') ++p;
    key_vals[key] = std::string(start, p - start);
    ++read_line;
    // accept "\n" as well as "\r\n" line endings
    if (*p == '\r') ++p;
    if (*p == '\n') ++p;
  }
  *used_len = p - str;

  if (key_vals.count("num_leaves") <= 0) {
    Log::Fatal("Tree model should contain num_leaves field");
  }

  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);

  if (key_vals.count("num_cat") <= 0) {
    Log::Fatal("Tree model should contain num_cat field");
  }

  Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);

  if (key_vals.count("leaf_value")) {
    leaf_value_ = CommonC::StringToArray<double>(key_vals["leaf_value"], num_leaves_);
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

  if (key_vals.count("shrinkage")) {
    CommonC::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }

  if (key_vals.count("is_linear")) {
    int is_linear_int;
    Common::Atoi(key_vals["is_linear"].c_str(), &is_linear_int);
    is_linear_ = static_cast<bool>(is_linear_int);
  } else {
    // model strings without the key describe a non-linear tree; assign
    // explicitly so the check below never reads a value this constructor
    // did not set
    is_linear_ = false;
  }

  // a single-leaf, non-linear tree has no splits to load
  if ((num_leaves_ <= 1) && !is_linear_) { return; }

  if (key_vals.count("left_child")) {
    left_child_ = CommonC::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
  }

  if (key_vals.count("right_child")) {
    right_child_ = CommonC::StringToArrayFast<int>(key_vals["right_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (key_vals.count("split_feature")) {
    split_feature_ = CommonC::StringToArrayFast<int>(key_vals["split_feature"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (key_vals.count("threshold")) {
    threshold_ = CommonC::StringToArray<double>(key_vals["threshold"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  // optional per-node / per-leaf statistics: sized but left value-initialized when absent
  if (key_vals.count("split_gain")) {
    split_gain_ = CommonC::StringToArrayFast<float>(key_vals["split_gain"], num_leaves_ - 1);
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_count")) {
    internal_count_ = CommonC::StringToArrayFast<int>(key_vals["internal_count"], num_leaves_ - 1);
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_value")) {
    internal_value_ = CommonC::StringToArrayFast<double>(key_vals["internal_value"], num_leaves_ - 1);
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_weight")) {
    internal_weight_ = CommonC::StringToArrayFast<double>(key_vals["internal_weight"], num_leaves_ - 1);
  } else {
    internal_weight_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("leaf_weight")) {
    leaf_weight_ = CommonC::StringToArray<double>(key_vals["leaf_weight"], num_leaves_);
  } else {
    leaf_weight_.resize(num_leaves_);
  }

  if (key_vals.count("leaf_count")) {
    leaf_count_ = CommonC::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (key_vals.count("decision_type")) {
    decision_type_ = CommonC::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }

  if (is_linear_) {
    if (key_vals.count("leaf_const")) {
      leaf_const_ = Common::StringToArrayFast<double>(key_vals["leaf_const"], num_leaves_);
    } else {
      leaf_const_.resize(num_leaves_);
    }
    std::vector<int> num_feat;
    if (key_vals.count("num_features")) {
      num_feat = Common::StringToArrayFast<int>(key_vals["num_features"], num_leaves_);
    }
    leaf_coeff_.resize(num_leaves_);
    leaf_features_.resize(num_leaves_);
    leaf_features_inner_.resize(num_leaves_);
    if (num_feat.size() > 0) {
      // leaf_features / leaf_coeff are stored flattened across leaves;
      // num_feat[i] gives the length of leaf i's slice
      int total_num_feat = 0;
      for (size_t i = 0; i < num_feat.size(); ++i) { total_num_feat += num_feat[i]; }
      std::vector<int> all_leaf_features;
      if (key_vals.count("leaf_features")) {
        all_leaf_features = Common::StringToArrayFast<int>(key_vals["leaf_features"], total_num_feat);
      }
      std::vector<double> all_leaf_coeff;
      if (key_vals.count("leaf_coeff")) {
        all_leaf_coeff = Common::StringToArrayFast<double>(key_vals["leaf_coeff"], total_num_feat);
      }
      // running offset of the current leaf's slice in the flattened arrays
      int sum_num_feat = 0;
      for (int i = 0; i < num_leaves_; ++i) {
        if (num_feat[i] > 0) {
          if (key_vals.count("leaf_features"))  {
            leaf_features_[i].assign(all_leaf_features.begin() + sum_num_feat, all_leaf_features.begin() + sum_num_feat + num_feat[i]);
          }
          if (key_vals.count("leaf_coeff")) {
            leaf_coeff_[i].assign(all_leaf_coeff.begin() + sum_num_feat, all_leaf_coeff.begin() + sum_num_feat + num_feat[i]);
          }
        }
        sum_num_feat += num_feat[i];
      }
    }
  }

  if (num_cat_ > 0) {
    if (key_vals.count("cat_boundaries")) {
      cat_boundaries_ = CommonC::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
    } else {
      Log::Fatal("Tree model should contain cat_boundaries field.");
    }

    if (key_vals.count("cat_threshold")) {
      // the last boundary is the total number of threshold words
      cat_threshold_ = CommonC::StringToArrayFast<uint32_t>(key_vals["cat_threshold"], cat_boundaries_.back());
    } else {
      Log::Fatal("Tree model should contain cat_threshold field");
    }
  }
  // depth is not stored in the model string; -1 marks it as not yet computed
  max_depth_ = -1;
}

Guolin Ke's avatar
Guolin Ke committed
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
// Appends a new element at position `unique_depth` of `unique_path` and
// re-weights the pweight of every element in [0, unique_depth] for the
// lengthened path.
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  PathElement& added = unique_path[unique_depth];
  added.feature_index = feature_index;
  added.zero_fraction = zero_fraction;
  added.one_fraction = one_fraction;
  // only the very first element of a path starts with weight 1
  added.pweight = (unique_depth == 0 ? 1 : 0);

  const double path_len = static_cast<double>(unique_depth + 1);
  // walk from the end towards the front so that slot k + 1 is updated from
  // the not-yet-overwritten weight of slot k
  for (int k = unique_depth - 1; k >= 0; k--) {
    unique_path[k + 1].pweight += one_fraction*unique_path[k].pweight*(k + 1)
      / path_len;
    unique_path[k].pweight = zero_fraction*unique_path[k].pweight*(unique_depth - k)
      / path_len;
  }
}

// Reverses the weight update that ExtendPath applied for the element at
// `path_index`, then removes that element by shifting the later elements'
// feature_index / zero_fraction / one_fraction one slot towards the front
// (pweight is not shifted).
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;

  // the branch on one_fraction does not depend on the loop index, so each case
  // gets its own loop
  if (one_fraction != 0) {
    double carry = unique_path[unique_depth].pweight;
    for (int k = unique_depth - 1; k >= 0; --k) {
      auto& weight = unique_path[k].pweight;
      const double before = weight;
      weight = carry*(unique_depth + 1)
        / static_cast<double>((k + 1)*one_fraction);
      carry = before - weight*zero_fraction*(unique_depth - k)
        / static_cast<double>(unique_depth + 1);
    }
  } else {
    for (int k = unique_depth - 1; k >= 0; --k) {
      auto& weight = unique_path[k].pweight;
      weight = (weight*(unique_depth + 1))
        / static_cast<double>(zero_fraction*(unique_depth - k));
    }
  }

  // close the gap left by the removed element
  for (int k = path_index; k < unique_depth; ++k) {
    const PathElement& next = unique_path[k + 1];
    PathElement& slot = unique_path[k];
    slot.feature_index = next.feature_index;
    slot.zero_fraction = next.zero_fraction;
    slot.one_fraction = next.one_fraction;
  }
}

double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;
  double total = 0;
  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      const double tmp = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      total += tmp;
      next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth - i)
                                                                     / static_cast<double>(unique_depth + 1));
    } else {
      total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
                                                           / static_cast<double>(unique_depth + 1));
    }
  }
  return total;
}

// recursive computation of SHAP values for a decision tree
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
888
  PathElement* unique_path = parent_unique_path + unique_depth;
889
890
891
  if (unique_depth > 0) {
    std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  }
Guolin Ke's avatar
Guolin Ke committed
892
893
894
895
896
897
898
899
900
901
902
903
904
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

    // internal node
  } else {
905
    const int hot_index = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
906
907
908
909
910
911
912
913
914
915
916
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
917
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
Guolin Ke's avatar
Guolin Ke committed
918
919
920
921
922
923
924
925
926
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
927
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
928
929

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
930
             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
931
932
933
  }
}

934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
// recursive sparse computation of SHAP values for a decision tree
void Tree::TreeSHAPByMap(const std::unordered_map<int, double>& feature_values, std::unordered_map<int, double>* phi,
                         int node, int unique_depth,
                         PathElement *parent_unique_path, double parent_zero_fraction,
                         double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
  PathElement* unique_path = parent_unique_path + unique_depth;
  if (unique_depth > 0) {
    std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  }
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      (*phi)[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

  // internal node
  } else {
    const int hot_index = Decision(feature_values.count(split_feature_[node]) > 0 ? feature_values.at(split_feature_[node]) : 0.0f, node);
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAPByMap(feature_values, phi, hot_index, unique_depth + 1, unique_path,
                  hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);

    TreeSHAPByMap(feature_values, phi, cold_index, unique_depth + 1, unique_path,
                  cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
  }
}

986
987
988
989
990
double Tree::ExpectedValue() const {
  if (num_leaves_ == 1) return LeafOutput(0);
  const double total_count = internal_count_[0];
  double exp_value = 0.0;
  for (int i = 0; i < num_leaves(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
991
    exp_value += (leaf_count_[i] / total_count)*LeafOutput(i);
Guolin Ke's avatar
Guolin Ke committed
992
  }
993
  return exp_value;
Guolin Ke's avatar
Guolin Ke committed
994
995
}

996
997
998
999
1000
1001
1002
1003
1004
1005
1006
// Sets max_depth_ to the largest entry of leaf_depth_ over all leaves.
// A single-leaf tree has depth 0. If leaf_depth_ is empty it is first rebuilt
// via RecomputeLeafDepths(0, 0).
void Tree::RecomputeMaxDepth() {
  if (num_leaves_ == 1) {
    max_depth_ = 0;
    return;
  }
  if (leaf_depth_.empty()) {
    RecomputeLeafDepths(0, 0);
  }
  auto deepest = leaf_depth_[0];
  for (int leaf = 1; leaf < num_leaves(); ++leaf) {
    if (leaf_depth_[leaf] > deepest) {
      deepest = leaf_depth_[leaf];
    }
  }
  max_depth_ = deepest;
}

Guolin Ke's avatar
Guolin Ke committed
1010
}  // namespace LightGBM