tree.cpp 26.8 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <LightGBM/tree.h>

#include <functional>
#include <iomanip>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

#include <LightGBM/utils/threading.h>
#include <LightGBM/utils/common.h>

#include <LightGBM/dataset.h>

namespace LightGBM {

// Build an empty tree: a single root leaf (value 0, depth 0, no parent) with storage
// reserved for up to `max_leaves` leaves and `max_leaves - 1` internal (split) nodes.
Tree::Tree(int max_leaves)
  :max_leaves_(max_leaves) {
  const int num_internal_nodes = max_leaves_ - 1;

  // storage indexed by leaf
  leaf_parent_.resize(max_leaves_);
  leaf_value_.resize(max_leaves_);
  leaf_count_.resize(max_leaves_);
  leaf_depth_.resize(max_leaves_);

  // storage indexed by internal node
  left_child_.resize(num_internal_nodes);
  right_child_.resize(num_internal_nodes);
  split_feature_inner_.resize(num_internal_nodes);
  split_feature_.resize(num_internal_nodes);
  threshold_in_bin_.resize(num_internal_nodes);
  threshold_.resize(num_internal_nodes);
  decision_type_.resize(num_internal_nodes, 0);
  split_gain_.resize(num_internal_nodes);
  internal_value_.resize(num_internal_nodes);
  internal_count_.resize(num_internal_nodes);

  // the only leaf is the root: depth 0, output 0, no parent
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_value_[0] = 0.0f;
  leaf_parent_[0] = -1;

  shrinkage_ = 1.0f;
  max_depth_ = -1;

  // no categorical splits yet; each boundary list starts with the offset 0
  num_cat_ = 0;
  cat_boundaries_.push_back(0);
  cat_boundaries_inner_.push_back(0);
}
Guolin Ke's avatar
Guolin Ke committed
45

Guolin Ke's avatar
Guolin Ke committed
46
47
48
// Nothing to release by hand: the arrays sized in the constructor clean themselves up.
Tree::~Tree() = default;

49
50
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
                double threshold_double, double left_value, double right_value,
51
                int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left) {
52
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
Guolin Ke's avatar
Guolin Ke committed
53
  int new_node_idx = num_leaves_ - 1;
Guolin Ke's avatar
Guolin Ke committed
54
  decision_type_[new_node_idx] = 0;
55
  SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
56
57
58
59
60
61
62
  SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
63
  }
Guolin Ke's avatar
Guolin Ke committed
64
  threshold_in_bin_[new_node_idx] = threshold_bin;
65
  threshold_[new_node_idx] = Common::AvoidInf(threshold_double);
66
67
68
  ++num_leaves_;
  return num_leaves_ - 1;
}
Guolin Ke's avatar
Guolin Ke committed
69

70
71
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                           const uint32_t* threshold, int num_threshold, double left_value, double right_value,
72
                           data_size_t left_cnt, data_size_t right_cnt, float gain, MissingType missing_type) {
73
74
75
76
77
78
79
80
81
82
83
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
  int new_node_idx = num_leaves_ - 1;
  decision_type_[new_node_idx] = 0;
  SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
  }
84
85
  threshold_in_bin_[new_node_idx] = num_cat_;
  threshold_[new_node_idx] = num_cat_;
86
  ++num_cat_;
87
88
89
90
91
92
93
94
  cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
  for (int i = 0; i < num_threshold; ++i) {
    cat_threshold_.push_back(threshold[i]);
  }
  cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
  for (int i = 0; i < num_threshold_bin; ++i) {
    cat_threshold_inner_.push_back(threshold_bin[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
95
96
97
98
  ++num_leaves_;
  return num_leaves_ - 1;
}

99
100
101
102
103
104
105
106
107
108
109
110
111
112
// PredictionFun expands to the body of a Threading::For worker that routes rows
// [start, end) through the tree and adds the reached leaf's value to `score`.
// The expansion site must have `data`, `score`, `start`, `end`, `default_bins`, `max_bins`
// and `leaf_value_` in scope; the macro itself declares `iter`, `i` and `node`.
//   niter        - number of BinIterators to open
//   fidx_in_iter - feature of iterator `i` (may refer to the iterator loop variable `i`)
//   start_pos    - position every iterator is Reset() to before the scan
//   decision_fun - returns the next node; a negative result encodes leaf ~node
//   iter_idx     - iterator to read for the current `node`
//   data_idx     - row read from the iterator and updated in `score` (may refer to row `i`)
// NOTE: the last line ends in a backslash, so the (blank) line after it is spliced into the macro.
#define PredictionFun(niter, fidx_in_iter, start_pos, decision_fun, iter_idx, data_idx) \
std::vector<std::unique_ptr<BinIterator>> iter((niter)); \
for (int i = 0; i < (niter); ++i) { \
  iter[i].reset(data->FeatureIterator((fidx_in_iter))); \
  iter[i]->Reset((start_pos)); \
}\
for (data_size_t i = start; i < end; ++i) {\
  int node = 0;\
  while (node >= 0) {\
    node = decision_fun(iter[(iter_idx)]->Get((data_idx)), node, default_bins[node], max_bins[node]);\
  }\
  score[(data_idx)] += static_cast<double>(leaf_value_[~node]);\
}\

113
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
114
115
116
117
118
119
120
121
122
  if (num_leaves_ <= 1) {
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[i] += leaf_value_[0];
      }
    }
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
123
124
125
126
127
128
129
130
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
131
  if (num_cat_ > 0) {
132
    if (data->num_features() > num_leaves_ - 1) {
133
134
135
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
136
137
      });
    } else {
138
139
140
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
141
142
      });
    }
Guolin Ke's avatar
Guolin Ke committed
143
  } else {
144
    if (data->num_features() > num_leaves_ - 1) {
145
146
147
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
148
149
      });
    } else {
150
151
152
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
153
154
      });
    }
Guolin Ke's avatar
Guolin Ke committed
155
  }
Guolin Ke's avatar
Guolin Ke committed
156
157
}

Guolin Ke's avatar
Guolin Ke committed
158
void Tree::AddPredictionToScore(const Dataset* data,
Guolin Ke's avatar
Guolin Ke committed
159
160
161
                                const data_size_t* used_data_indices,
                                data_size_t num_data, double* score) const {
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
162
163
164
165
166
167
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[used_data_indices[i]] += leaf_value_[0];
      }
    }
Guolin Ke's avatar
Guolin Ke committed
168
    return;
Guolin Ke's avatar
Guolin Ke committed
169
  }
Guolin Ke's avatar
Guolin Ke committed
170
171
172
173
174
175
176
177
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
178
  if (num_cat_ > 0) {
179
    if (data->num_features() > num_leaves_ - 1) {
180
181
182
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
183
184
      });
    } else {
185
186
187
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
188
189
      });
    }
Guolin Ke's avatar
Guolin Ke committed
190
  } else {
191
    if (data->num_features() > num_leaves_ - 1) {
192
193
194
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
195
196
      });
    } else {
197
198
199
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
200
201
      });
    }
Guolin Ke's avatar
Guolin Ke committed
202
  }
Guolin Ke's avatar
Guolin Ke committed
203
204
}

205
206
// PredictionFun is only meant for the AddPredictionToScore overloads; keep it out of the rest of the file.
#undef PredictionFun

Guolin Ke's avatar
Guolin Ke committed
207
std::string Tree::ToString() const {
208
  std::stringstream str_buf;
209
210
  str_buf << "num_leaves=" << num_leaves_ << '\n';
  str_buf << "num_cat=" << num_cat_ << '\n';
211
  str_buf << "split_feature="
212
    << Common::ArrayToStringFast(split_feature_, num_leaves_ - 1) << '\n';
213
  str_buf << "split_gain="
214
    << Common::ArrayToStringFast(split_gain_, num_leaves_ - 1) << '\n';
215
  str_buf << "threshold="
216
    << Common::ArrayToString(threshold_, num_leaves_ - 1) << '\n';
217
  str_buf << "decision_type="
218
    << Common::ArrayToStringFast(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1) << '\n';
219
  str_buf << "left_child="
220
    << Common::ArrayToStringFast(left_child_, num_leaves_ - 1) << '\n';
221
  str_buf << "right_child="
222
    << Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
223
  str_buf << "leaf_value="
224
    << Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
225
  str_buf << "leaf_count="
226
    << Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
227
  str_buf << "internal_value="
228
    << Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
229
  str_buf << "internal_count="
230
    << Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
231
232
  if (num_cat_ > 0) {
    str_buf << "cat_boundaries="
233
      << Common::ArrayToStringFast(cat_boundaries_, num_cat_ + 1) << '\n';
234
    str_buf << "cat_threshold="
235
      << Common::ArrayToStringFast(cat_threshold_, cat_threshold_.size()) << '\n';
236
  }
237
238
  str_buf << "shrinkage=" << shrinkage_ << '\n';
  str_buf << '\n';
239
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
240
241
}

Guolin Ke's avatar
Guolin Ke committed
242
std::string Tree::ToJSON() const {
243
  std::stringstream str_buf;
244
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
245
246
247
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
  str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
wxchan's avatar
wxchan committed
248
  if (num_leaves_ == 1) {
249
    str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
wxchan's avatar
wxchan committed
250
  } else {
251
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
wxchan's avatar
wxchan committed
252
  }
wxchan's avatar
wxchan committed
253

254
  return str_buf.str();
wxchan's avatar
wxchan committed
255
256
}

Guolin Ke's avatar
Guolin Ke committed
257
std::string Tree::NodeToJSON(int index) const {
258
  std::stringstream str_buf;
259
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
260
261
  if (index >= 0) {
    // non-leaf
262
263
264
265
    str_buf << "{" << '\n';
    str_buf << "\"split_index\":" << index << "," << '\n';
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
    str_buf << "\"split_gain\":" << split_gain_[index] << "," << '\n';
266
    if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
267
268
269
270
271
272
273
274
275
276
277
      int cat_idx = static_cast<int>(threshold_[index]);
      std::vector<int> cats;
      for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
        for (int j = 0; j < 32; ++j) {
          int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
          if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
            cats.push_back(cat);
          }
        }
      }
278
279
      str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << '\n';
      str_buf << "\"decision_type\":\"==\"," << '\n';
280
    } else {
281
282
      str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << '\n';
      str_buf << "\"decision_type\":\"<=\"," << '\n';
283
284
    }
    if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
285
      str_buf << "\"default_left\":true," << '\n';
286
    } else {
287
      str_buf << "\"default_left\":false," << '\n';
288
289
290
    }
    uint8_t missing_type = GetMissingType(decision_type_[index]);
    if (missing_type == 0) {
291
      str_buf << "\"missing_type\":\"None\"," << '\n';
292
    } else if (missing_type == 1) {
293
      str_buf << "\"missing_type\":\"Zero\"," << '\n';
294
    } else {
295
      str_buf << "\"missing_type\":\"NaN\"," << '\n';
296
    }
297
298
299
300
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
301
    str_buf << "}";
wxchan's avatar
wxchan committed
302
303
304
  } else {
    // leaf
    index = ~index;
305
306
307
308
    str_buf << "{" << '\n';
    str_buf << "\"leaf_index\":" << index << "," << '\n';
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
    str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
309
    str_buf << "}";
wxchan's avatar
wxchan committed
310
311
  }

312
  return str_buf.str();
wxchan's avatar
wxchan committed
313
314
}

Guolin Ke's avatar
Guolin Ke committed
315
316
317
318
std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
Guolin Ke's avatar
Guolin Ke committed
319
  if (missing_type == 0 || (missing_type == 1 && default_left && kZeroThreshold < threshold_[node])) {
Guolin Ke's avatar
Guolin Ke committed
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
    str_buf << "if (fval <= " << threshold_[node] << ") {";
  } else if (missing_type == 1) {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
  if (missing_type == 2) {
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
345
  int cat_idx = static_cast<int>(threshold_[node]);
346
347
348
349
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
Guolin Ke's avatar
Guolin Ke committed
350
351
352
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
353
std::string Tree::ToIfElse(int index, bool predict_leaf_index) const {
354
355
  std::stringstream str_buf;
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
356
  if (predict_leaf_index) {
357
358
359
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
360
361
  if (num_leaves_ <= 1) {
    str_buf << "return " << leaf_value_[0] << ";";
362
  } else {
363
364
365
366
367
368
369
370
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
371
372
373
374
375
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
376
    str_buf << NodeToIfElse(0, predict_leaf_index);
377
  }
378
  str_buf << " }" << '\n';
379

380
  // Predict func by Map to ifelse
381
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
382
  if (predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
383
384
385
    str_buf << "LeafByMap";
  } else {
    str_buf << "ByMap";
386
387
388
  }
  str_buf << "(const std::unordered_map<int, double>& arr) { ";
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
    str_buf << "return " << leaf_value_[0] << ";";
  } else {
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
404
    str_buf << NodeToIfElseByMap(0, predict_leaf_index);
405
  }
406
  str_buf << " }" << '\n';
407

408
409
410
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
411
std::string Tree::NodeToIfElse(int index, bool predict_leaf_index) const {
412
413
414
415
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
416
    str_buf << "fval = arr[" << split_feature_[index] << "];";
Guolin Ke's avatar
Guolin Ke committed
417
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
418
      str_buf << NumericalDecisionIfElse(index);
419
    } else {
420
      str_buf << CategoricalDecisionIfElse(index);
421
422
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
423
    str_buf << NodeToIfElse(left_child_[index], predict_leaf_index);
424
425
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
426
    str_buf << NodeToIfElse(right_child_[index], predict_leaf_index);
427
428
429
430
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
431
    if (predict_leaf_index) {
432
433
434
435
436
437
438
439
440
441
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
442
std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
443
444
445
446
447
448
449
450
451
452
453
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
    str_buf << "fval = arr.count(" << split_feature_[index] << ") > 0 ? arr.at(" << split_feature_[index] << ") : 0.0f;";
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
      str_buf << NumericalDecisionIfElse(index);
    } else {
      str_buf << CategoricalDecisionIfElse(index);
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
454
    str_buf << NodeToIfElseByMap(left_child_[index], predict_leaf_index);
455
456
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
457
    str_buf << NodeToIfElseByMap(right_child_[index], predict_leaf_index);
458
459
460
461
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
462
    if (predict_leaf_index) {
463
464
465
466
467
468
469
470
471
472
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

473
474
// Parse one tree from the text written by ToString(). The input is consecutive "key=value"
// lines (at most max_num_line of them) ending at an empty line or at the end of the buffer.
// On return `*used_len` holds the number of characters consumed. Malformed lines and missing
// mandatory fields are reported through Log::Fatal.
Tree::Tree(const char* str, size_t* used_len) {
  auto p = str;
  std::unordered_map<std::string, std::string> key_vals;
  // ToString() writes at most this many key=value lines per tree
  const int max_num_line = 15;
  int read_line = 0;
  while (read_line < max_num_line) {
    // an empty line or the end of the buffer terminates the tree block
    if (*p == '\r' || *p == '\n' || *p == '\0') break;
    auto start = p;
    // never scan past the current line or the terminating NUL while looking for '='
    while (*p != '=' && *p != '\r' && *p != '\n' && *p != '\0') ++p;
    if (*p != '=') {
      Log::Fatal("Tree model string format error, each line should be in key=value format");
    }
    std::string key(start, p - start);
    ++p;
    start = p;
    while (*p != '\r' && *p != '\n' && *p != '\0') ++p;
    key_vals[key] = std::string(start, p - start);
    ++read_line;
    if (*p == '\r') ++p;
    if (*p == '\n') ++p;
  }
  *used_len = p - str;

  if (key_vals.count("num_leaves") <= 0) {
    Log::Fatal("Tree model should contain num_leaves field");
  }
  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);
  if (num_leaves_ <= 0) {
    Log::Fatal("Tree model string format error, num_leaves should be positive");
  }

  if (key_vals.count("num_cat") <= 0) {
    Log::Fatal("Tree model should contain num_cat field");
  }
  Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);

  if (key_vals.count("leaf_value")) {
    leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], num_leaves_);
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

  // Scalar members are set before the single-leaf early return below; previously that return
  // left shrinkage_ and max_depth_ uninitialized for one-leaf trees.
  if (key_vals.count("shrinkage")) {
    Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }
  max_depth_ = -1;

  // a single-leaf tree has no internal-node arrays
  if (num_leaves_ <= 1) { return; }

  if (key_vals.count("left_child")) {
    left_child_ = Common::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
  }

  if (key_vals.count("right_child")) {
    right_child_ = Common::StringToArrayFast<int>(key_vals["right_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (key_vals.count("split_feature")) {
    split_feature_ = Common::StringToArrayFast<int>(key_vals["split_feature"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (key_vals.count("threshold")) {
    threshold_ = Common::StringToArray<double>(key_vals["threshold"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  // optional statistics default to zeros
  if (key_vals.count("split_gain")) {
    split_gain_ = Common::StringToArrayFast<float>(key_vals["split_gain"], num_leaves_ - 1);
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_count")) {
    internal_count_ = Common::StringToArrayFast<int>(key_vals["internal_count"], num_leaves_ - 1);
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_value")) {
    internal_value_ = Common::StringToArrayFast<double>(key_vals["internal_value"], num_leaves_ - 1);
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("leaf_count")) {
    leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (key_vals.count("decision_type")) {
    decision_type_ = Common::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }

  if (num_cat_ > 0) {
    if (key_vals.count("cat_boundaries")) {
      cat_boundaries_ = Common::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
    } else {
      Log::Fatal("Tree model should contain cat_boundaries field.");
    }

    if (key_vals.count("cat_threshold")) {
      cat_threshold_ = Common::StringToArrayFast<uint32_t>(key_vals["cat_threshold"], cat_boundaries_.back());
    } else {
      Log::Fatal("Tree model should contain cat_threshold field");
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
// Tree SHAP path extension (EXTEND step of the Tree SHAP algorithm).
// Writes a new element (feature_index, zero_fraction, one_fraction) at position `unique_depth`
// of `unique_path`. It then updates the pweight of elements 0..unique_depth in place so they
// account for the new element: part of each weight is carried to the next slot, scaled by
// one_fraction, and the remainder is scaled by zero_fraction.
// The caller owns `unique_path`, which must have room for at least unique_depth + 1 elements.
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  unique_path[unique_depth].feature_index = feature_index;
  unique_path[unique_depth].zero_fraction = zero_fraction;
  unique_path[unique_depth].one_fraction = one_fraction;
  // the first element of a path starts with all of the weight
  unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
  // walk backwards so each step reads unique_path[i].pweight before it is overwritten
  for (int i = unique_depth - 1; i >= 0; i--) {
    unique_path[i + 1].pweight += one_fraction*unique_path[i].pweight*(i + 1)
      / static_cast<double>(unique_depth + 1);
    unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth - i)
      / static_cast<double>(unique_depth + 1);
  }
}

// Tree SHAP path removal (UNWIND step), the inverse of ExtendPath for element `path_index`.
// First it recomputes the pweights of slots 0..unique_depth-1 as they would be without that
// element. Then it shifts the feature_index / zero_fraction / one_fraction of the following
// elements down by one slot. The caller is expected to decrement its own depth afterwards
// (TreeSHAP does `unique_depth -= 1`).
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;

  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      // undo the "carried forward" share, then the zero_fraction scaling of slot i
      const double tmp = unique_path[i].pweight;
      unique_path[i].pweight = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth - i)
        / static_cast<double>(unique_depth + 1);
    } else {
      // nothing was carried forward; only the zero_fraction scaling needs undoing
      unique_path[i].pweight = (unique_path[i].pweight*(unique_depth + 1))
        / static_cast<double>(zero_fraction*(unique_depth - i));
    }
  }

  // close the gap left by the removed element (pweights were recomputed above)
  for (int i = path_index; i < unique_depth; ++i) {
    unique_path[i].feature_index = unique_path[i + 1].feature_index;
    unique_path[i].zero_fraction = unique_path[i + 1].zero_fraction;
    unique_path[i].one_fraction = unique_path[i + 1].one_fraction;
  }
}

// Returns the sum of the pweights that UnwindPath(unique_path, unique_depth, path_index)
// would produce, without modifying the path. TreeSHAP uses it at a leaf as the weight of
// element `path_index`'s feature.
double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;
  double total = 0;
  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      // same recurrence as UnwindPath, accumulated instead of stored
      const double tmp = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      total += tmp;
      next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth - i)
                                                                     / static_cast<double>(unique_depth + 1));
    } else {
      total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
                                                           / static_cast<double>(unique_depth + 1));
    }
  }
  return total;
}

// recursive computation of SHAP values for a decision tree
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
654
  PathElement* unique_path = parent_unique_path + unique_depth;
Guolin Ke's avatar
Guolin Ke committed
655
656
657
658
659
660
661
662
663
664
665
666
667
668
  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

    // internal node
  } else {
669
    const int hot_index = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
670
671
672
673
674
675
676
677
678
679
680
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
681
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
Guolin Ke's avatar
Guolin Ke committed
682
683
684
685
686
687
688
689
690
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
691
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
692
693

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
694
             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
695
696
697
  }
}

698
699
700
701
702
double Tree::ExpectedValue() const {
  if (num_leaves_ == 1) return LeafOutput(0);
  const double total_count = internal_count_[0];
  double exp_value = 0.0;
  for (int i = 0; i < num_leaves(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
703
    exp_value += (leaf_count_[i] / total_count)*LeafOutput(i);
Guolin Ke's avatar
Guolin Ke committed
704
  }
705
  return exp_value;
Guolin Ke's avatar
Guolin Ke committed
706
707
}

708
709
710
711
712
713
714
715
716
717
718
// Set max_depth_ to the depth of the deepest leaf (0 for a single-leaf tree).
// If leaf depths were never filled in (the text-parsing constructor does not set
// leaf_depth_), they are rebuilt first from the root.
void Tree::RecomputeMaxDepth() {
  if (num_leaves_ == 1) {
    max_depth_ = 0;
    return;
  }
  if (leaf_depth_.size() == 0) {
    RecomputeLeafDepths(0, 0);
  }
  auto deepest = leaf_depth_[0];
  for (int leaf = 1; leaf < num_leaves(); ++leaf) {
    deepest = (leaf_depth_[leaf] > deepest) ? leaf_depth_[leaf] : deepest;
  }
  max_depth_ = deepest;
}

Guolin Ke's avatar
Guolin Ke committed
722
}  // namespace LightGBM