tree.cpp 26.8 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <LightGBM/tree.h>

#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>

#include <algorithm>
#include <functional>
#include <iomanip>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

namespace LightGBM {

Tree::Tree(int max_leaves)
  :max_leaves_(max_leaves) {
  // A binary tree with max_leaves_ leaves has at most max_leaves_ - 1 split (internal)
  // nodes, so per-split arrays get one slot fewer than per-leaf arrays.
  const int max_splits = max_leaves_ - 1;

  // per-split storage
  left_child_.resize(max_splits);
  right_child_.resize(max_splits);
  split_feature_inner_.resize(max_splits);
  split_feature_.resize(max_splits);
  threshold_in_bin_.resize(max_splits);
  threshold_.resize(max_splits);
  decision_type_.resize(max_splits, 0);
  split_gain_.resize(max_splits);
  internal_value_.resize(max_splits);
  internal_count_.resize(max_splits);

  // per-leaf storage
  leaf_parent_.resize(max_leaves_);
  leaf_value_.resize(max_leaves_);
  leaf_count_.resize(max_leaves_);
  leaf_depth_.resize(max_leaves_);

  // Start as a single-leaf tree: leaf 0 is the root (depth 0, parent -1, output 0).
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_parent_[0] = -1;
  leaf_value_[0] = 0.0f;

  shrinkage_ = 1.0f;
  num_cat_ = 0;
  // Offsets into the categorical threshold arrays; SplitCategorical appends one
  // entry per categorical split, so they start from a single leading 0.
  cat_boundaries_.push_back(0);
  cat_boundaries_inner_.push_back(0);
  // presumably "not computed yet"; RecomputeMaxDepth() fills it in
  max_depth_ = -1;
}
Guolin Ke's avatar
Guolin Ke committed
46

Guolin Ke's avatar
Guolin Ke committed
47
Tree::~Tree() {
Guolin Ke's avatar
Guolin Ke committed
48

Guolin Ke's avatar
Guolin Ke committed
49
50
}

51
52
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
                double threshold_double, double left_value, double right_value,
53
                int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left) {
54
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
Guolin Ke's avatar
Guolin Ke committed
55
  int new_node_idx = num_leaves_ - 1;
Guolin Ke's avatar
Guolin Ke committed
56
  decision_type_[new_node_idx] = 0;
57
  SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
62
63
64
  SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
65
  }
Guolin Ke's avatar
Guolin Ke committed
66
  threshold_in_bin_[new_node_idx] = threshold_bin;
67
  threshold_[new_node_idx] = Common::AvoidInf(threshold_double);
68
69
70
  ++num_leaves_;
  return num_leaves_ - 1;
}
Guolin Ke's avatar
Guolin Ke committed
71

72
73
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                           const uint32_t* threshold, int num_threshold, double left_value, double right_value,
74
                           data_size_t left_cnt, data_size_t right_cnt, float gain, MissingType missing_type) {
75
76
77
78
79
80
81
82
83
84
85
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
  int new_node_idx = num_leaves_ - 1;
  decision_type_[new_node_idx] = 0;
  SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
  }
86
87
  threshold_in_bin_[new_node_idx] = num_cat_;
  threshold_[new_node_idx] = num_cat_;
88
  ++num_cat_;
89
90
91
92
93
94
95
96
  cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
  for (int i = 0; i < num_threshold; ++i) {
    cat_threshold_.push_back(threshold[i]);
  }
  cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
  for (int i = 0; i < num_threshold_bin; ++i) {
    cat_threshold_inner_.push_back(threshold_bin[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
97
98
99
100
  ++num_leaves_;
  return num_leaves_ - 1;
}

101
102
103
104
105
106
107
108
109
110
111
112
113
114
// Shared body of the AddPredictionToScore overloads below. It is expanded inside a
// Threading::For chunk lambda and relies on the enclosing scope providing `data`,
// `start`, `end`, `score`, `default_bins`, `max_bins` and `leaf_value_`.
//   niter        - number of BinIterators to create
//   fidx_in_iter - inner feature index for iterator `i` (may refer to the loop variable `i`)
//   start_pos    - position every iterator is Reset() to before the row loop
//   decision_fun - called as decision_fun(bin, node, default_bin, max_bin) and returns the
//                  next node; a negative result ~k means leaf k was reached
//   iter_idx     - which iterator to read for the current `node` (may refer to `node`)
//   data_idx     - row to read and to accumulate into `score` (may refer to the row loop's `i`)
// NOTE: the final line ends in a backslash, so the line after the macro must stay blank.
#define PredictionFun(niter, fidx_in_iter, start_pos, decision_fun, iter_idx, data_idx) \
std::vector<std::unique_ptr<BinIterator>> iter((niter)); \
for (int i = 0; i < (niter); ++i) { \
  iter[i].reset(data->FeatureIterator((fidx_in_iter))); \
  iter[i]->Reset((start_pos)); \
}\
for (data_size_t i = start; i < end; ++i) {\
  int node = 0;\
  while (node >= 0) {\
    node = decision_fun(iter[(iter_idx)]->Get((data_idx)), node, default_bins[node], max_bins[node]);\
  }\
  score[(data_idx)] += static_cast<double>(leaf_value_[~node]);\
}\

115
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
116
117
118
119
120
121
122
123
124
  if (num_leaves_ <= 1) {
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[i] += leaf_value_[0];
      }
    }
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
125
126
127
128
129
130
131
132
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
133
  if (num_cat_ > 0) {
134
    if (data->num_features() > num_leaves_ - 1) {
135
136
137
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
138
139
      });
    } else {
140
141
142
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
143
144
      });
    }
Guolin Ke's avatar
Guolin Ke committed
145
  } else {
146
    if (data->num_features() > num_leaves_ - 1) {
147
148
149
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
150
151
      });
    } else {
152
153
154
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
155
156
      });
    }
Guolin Ke's avatar
Guolin Ke committed
157
  }
Guolin Ke's avatar
Guolin Ke committed
158
159
}

Guolin Ke's avatar
Guolin Ke committed
160
void Tree::AddPredictionToScore(const Dataset* data,
Guolin Ke's avatar
Guolin Ke committed
161
162
163
                                const data_size_t* used_data_indices,
                                data_size_t num_data, double* score) const {
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
164
165
166
167
168
169
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[used_data_indices[i]] += leaf_value_[0];
      }
    }
Guolin Ke's avatar
Guolin Ke committed
170
    return;
Guolin Ke's avatar
Guolin Ke committed
171
  }
Guolin Ke's avatar
Guolin Ke committed
172
173
174
175
176
177
178
179
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
180
  if (num_cat_ > 0) {
181
    if (data->num_features() > num_leaves_ - 1) {
182
183
184
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
185
186
      });
    } else {
187
188
189
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
190
191
      });
    }
Guolin Ke's avatar
Guolin Ke committed
192
  } else {
193
    if (data->num_features() > num_leaves_ - 1) {
194
195
196
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
197
198
      });
    } else {
199
200
201
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
202
203
      });
    }
Guolin Ke's avatar
Guolin Ke committed
204
  }
Guolin Ke's avatar
Guolin Ke committed
205
206
}

207
208
#undef PredictionFun

Guolin Ke's avatar
Guolin Ke committed
209
std::string Tree::ToString() const {
210
  std::stringstream str_buf;
211
212
  str_buf << "num_leaves=" << num_leaves_ << '\n';
  str_buf << "num_cat=" << num_cat_ << '\n';
213
  str_buf << "split_feature="
214
    << Common::ArrayToStringFast(split_feature_, num_leaves_ - 1) << '\n';
215
  str_buf << "split_gain="
216
    << Common::ArrayToStringFast(split_gain_, num_leaves_ - 1) << '\n';
217
  str_buf << "threshold="
218
    << Common::ArrayToString(threshold_, num_leaves_ - 1) << '\n';
219
  str_buf << "decision_type="
220
    << Common::ArrayToStringFast(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1) << '\n';
221
  str_buf << "left_child="
222
    << Common::ArrayToStringFast(left_child_, num_leaves_ - 1) << '\n';
223
  str_buf << "right_child="
224
    << Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
225
  str_buf << "leaf_value="
226
    << Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
227
  str_buf << "leaf_count="
228
    << Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
229
  str_buf << "internal_value="
230
    << Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
231
  str_buf << "internal_count="
232
    << Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
233
234
  if (num_cat_ > 0) {
    str_buf << "cat_boundaries="
235
      << Common::ArrayToStringFast(cat_boundaries_, num_cat_ + 1) << '\n';
236
    str_buf << "cat_threshold="
237
      << Common::ArrayToStringFast(cat_threshold_, cat_threshold_.size()) << '\n';
238
  }
239
240
  str_buf << "shrinkage=" << shrinkage_ << '\n';
  str_buf << '\n';
241
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
242
243
}

Guolin Ke's avatar
Guolin Ke committed
244
std::string Tree::ToJSON() const {
245
  std::stringstream str_buf;
246
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
247
248
249
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
  str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
wxchan's avatar
wxchan committed
250
  if (num_leaves_ == 1) {
251
    str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
wxchan's avatar
wxchan committed
252
  } else {
253
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
wxchan's avatar
wxchan committed
254
  }
wxchan's avatar
wxchan committed
255

256
  return str_buf.str();
wxchan's avatar
wxchan committed
257
258
}

Guolin Ke's avatar
Guolin Ke committed
259
std::string Tree::NodeToJSON(int index) const {
260
  std::stringstream str_buf;
261
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
262
263
  if (index >= 0) {
    // non-leaf
264
265
266
267
    str_buf << "{" << '\n';
    str_buf << "\"split_index\":" << index << "," << '\n';
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
    str_buf << "\"split_gain\":" << split_gain_[index] << "," << '\n';
268
    if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
269
270
271
272
273
274
275
276
277
278
279
      int cat_idx = static_cast<int>(threshold_[index]);
      std::vector<int> cats;
      for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
        for (int j = 0; j < 32; ++j) {
          int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
          if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
            cats.push_back(cat);
          }
        }
      }
280
281
      str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << '\n';
      str_buf << "\"decision_type\":\"==\"," << '\n';
282
    } else {
283
284
      str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << '\n';
      str_buf << "\"decision_type\":\"<=\"," << '\n';
285
286
    }
    if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
287
      str_buf << "\"default_left\":true," << '\n';
288
    } else {
289
      str_buf << "\"default_left\":false," << '\n';
290
291
292
    }
    uint8_t missing_type = GetMissingType(decision_type_[index]);
    if (missing_type == 0) {
293
      str_buf << "\"missing_type\":\"None\"," << '\n';
294
    } else if (missing_type == 1) {
295
      str_buf << "\"missing_type\":\"Zero\"," << '\n';
296
    } else {
297
      str_buf << "\"missing_type\":\"NaN\"," << '\n';
298
    }
299
300
301
302
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
303
    str_buf << "}";
wxchan's avatar
wxchan committed
304
305
306
  } else {
    // leaf
    index = ~index;
307
308
309
310
    str_buf << "{" << '\n';
    str_buf << "\"leaf_index\":" << index << "," << '\n';
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
    str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
311
    str_buf << "}";
wxchan's avatar
wxchan committed
312
313
  }

314
  return str_buf.str();
wxchan's avatar
wxchan committed
315
316
}

Guolin Ke's avatar
Guolin Ke committed
317
318
319
320
std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
Guolin Ke's avatar
Guolin Ke committed
321
  if (missing_type == 0 || (missing_type == 1 && default_left && kZeroThreshold < threshold_[node])) {
Guolin Ke's avatar
Guolin Ke committed
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
    str_buf << "if (fval <= " << threshold_[node] << ") {";
  } else if (missing_type == 1) {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
  if (missing_type == 2) {
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
347
348
349
350
351
  int cat_idx = int(threshold_[node]);
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
Guolin Ke's avatar
Guolin Ke committed
352
353
354
355
  return str_buf.str();
}

std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) const {
356
357
358
359
360
361
  std::stringstream str_buf;
  str_buf << "double PredictTree" << index;
  if (is_predict_leaf_index) {
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
362
363
  if (num_leaves_ <= 1) {
    str_buf << "return " << leaf_value_[0] << ";";
364
  } else {
365
366
367
368
369
370
371
372
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
373
374
375
376
377
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
378
379
    str_buf << NodeToIfElse(0, is_predict_leaf_index);
  }
380
  str_buf << " }" << '\n';
381
382
383
384

  //Predict func by Map to ifelse
  str_buf << "double PredictTree" << index;
  if (is_predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
385
386
387
    str_buf << "LeafByMap";
  } else {
    str_buf << "ByMap";
388
389
390
  }
  str_buf << "(const std::unordered_map<int, double>& arr) { ";
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
    str_buf << "return " << leaf_value_[0] << ";";
  } else {
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
    str_buf << NodeToIfElseByMap(0, is_predict_leaf_index);
407
  }
408
  str_buf << " }" << '\n';
409

410
411
412
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
413
std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) const {
414
415
416
417
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
418
    str_buf << "fval = arr[" << split_feature_[index] << "];";
Guolin Ke's avatar
Guolin Ke committed
419
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
420
      str_buf << NumericalDecisionIfElse(index);
421
    } else {
422
      str_buf << CategoricalDecisionIfElse(index);
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
    }
    // left subtree
    str_buf << NodeToIfElse(left_child_[index], is_predict_leaf_index);
    str_buf << " } else { ";
    // right subtree
    str_buf << NodeToIfElse(right_child_[index], is_predict_leaf_index);
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
    if (is_predict_leaf_index) {
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
std::string Tree::NodeToIfElseByMap(int index, bool is_predict_leaf_index) const {
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
    str_buf << "fval = arr.count(" << split_feature_[index] << ") > 0 ? arr.at(" << split_feature_[index] << ") : 0.0f;";
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
      str_buf << NumericalDecisionIfElse(index);
    } else {
      str_buf << CategoricalDecisionIfElse(index);
    }
    // left subtree
    str_buf << NodeToIfElseByMap(left_child_[index], is_predict_leaf_index);
    str_buf << " } else { ";
    // right subtree
    str_buf << NodeToIfElseByMap(right_child_[index], is_predict_leaf_index);
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
    if (is_predict_leaf_index) {
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

475
476
Tree::Tree(const char* str, size_t* used_len) {
  auto p = str;
Guolin Ke's avatar
Guolin Ke committed
477
  std::unordered_map<std::string, std::string> key_vals;
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
  const int max_num_line = 15;
  int read_line = 0;
  while (read_line < max_num_line) {
    if (*p == '\r' || *p == '\n') break;
    auto start = p;
    while (*p != '=') ++p; 
    std::string key(start, p - start);
    ++p;
    start = p;
    while (*p != '\r' && *p != '\n') ++p;
    key_vals[key] = std::string(start, p - start);
    ++read_line;
    if (*p == '\r') ++p;
    if (*p == '\n') ++p;
  }
  *used_len = p - str;

495
  if (key_vals.count("num_leaves") <= 0) {
496
    Log::Fatal("Tree model should contain num_leaves field");
Guolin Ke's avatar
Guolin Ke committed
497
498
499
500
  }

  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);

501
  if (key_vals.count("num_cat") <= 0) {
502
    Log::Fatal("Tree model should contain num_cat field");
503
504
505
506
  }

  Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);

507
  if (key_vals.count("leaf_value")) {
508
    leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], num_leaves_);
509
510
511
512
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

513
514
  if (num_leaves_ <= 1) { return; }

Guolin Ke's avatar
Guolin Ke committed
515
  if (key_vals.count("left_child")) {
516
    left_child_ = Common::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
517
518
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
519
520
  }

Guolin Ke's avatar
Guolin Ke committed
521
  if (key_vals.count("right_child")) {
522
    right_child_ = Common::StringToArrayFast<int>(key_vals["right_child"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
523
524
525
526
527
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (key_vals.count("split_feature")) {
528
    split_feature_ = Common::StringToArrayFast<int>(key_vals["split_feature"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
529
530
531
532
533
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (key_vals.count("threshold")) {
534
    threshold_ = Common::StringToArray<double>(key_vals["threshold"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
535
536
537
538
539
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  if (key_vals.count("split_gain")) {
540
    split_gain_ = Common::StringToArrayFast<float>(key_vals["split_gain"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
541
542
543
544
545
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_count")) {
546
    internal_count_ = Common::StringToArrayFast<int>(key_vals["internal_count"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
547
548
549
550
551
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_value")) {
552
    internal_value_ = Common::StringToArrayFast<double>(key_vals["internal_value"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
553
554
555
556
557
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("leaf_count")) {
558
    leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
Guolin Ke's avatar
Guolin Ke committed
559
560
561
562
563
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (key_vals.count("decision_type")) {
564
    decision_type_ = Common::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
565
566
567
568
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }

569
570
  if (num_cat_ > 0) {
    if (key_vals.count("cat_boundaries")) {
571
      cat_boundaries_ = Common::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
572
573
574
575
576
    } else {
      Log::Fatal("Tree model should contain cat_boundaries field.");
    }

    if (key_vals.count("cat_threshold")) {
577
      cat_threshold_ = Common::StringToArrayFast<uint32_t>(key_vals["cat_threshold"], cat_boundaries_.back());
578
    } else {
579
      Log::Fatal("Tree model should contain cat_threshold field");
580
581
582
    }
  }

Guolin Ke's avatar
Guolin Ke committed
583
584
585
586
587
  if (key_vals.count("shrinkage")) {
    Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }
588
  max_depth_ = -1;
Guolin Ke's avatar
Guolin Ke committed
589
590
}

Guolin Ke's avatar
Guolin Ke committed
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
// Appends the element (feature_index, zero_fraction, one_fraction) at position
// unique_depth of unique_path and updates the pweight of elements 0..unique_depth.
// Mirrors the EXTEND step of the Tree SHAP algorithm (Lundberg et al.). The arithmetic
// order is kept exactly as written so floating-point results stay reproducible.
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  unique_path[unique_depth].feature_index = feature_index;
  unique_path[unique_depth].zero_fraction = zero_fraction;
  unique_path[unique_depth].one_fraction = one_fraction;
  // the very first element starts with all the weight; later ones start empty
  unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
  // Walk downwards so element i is read before it is overwritten: part of its weight moves
  // to slot i + 1 (scaled by one_fraction) and the rest stays (scaled by zero_fraction).
  for (int i = unique_depth - 1; i >= 0; i--) {
    unique_path[i + 1].pweight += one_fraction*unique_path[i].pweight*(i + 1)
      / static_cast<double>(unique_depth + 1);
    unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth - i)
      / static_cast<double>(unique_depth + 1);
  }
}

// Inverse of ExtendPath: removes the element at path_index from unique_path (whose last
// element is at unique_depth), undoing its effect on pweights 0..unique_depth-1. It then
// shifts the feature/fraction fields of the later elements down by one slot.
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;

  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      const double tmp = unique_path[i].pweight;
      unique_path[i].pweight = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth - i)
        / static_cast<double>(unique_depth + 1);
    } else {
      // NOTE(review): assumes zero_fraction != 0 whenever one_fraction == 0
      unique_path[i].pweight = (unique_path[i].pweight*(unique_depth + 1))
        / static_cast<double>(zero_fraction*(unique_depth - i));
    }
  }

  // pweights were recomputed above; only the per-element fields need shifting
  for (int i = path_index; i < unique_depth; ++i) {
    unique_path[i].feature_index = unique_path[i + 1].feature_index;
    unique_path[i].zero_fraction = unique_path[i + 1].zero_fraction;
    unique_path[i].one_fraction = unique_path[i + 1].one_fraction;
  }
}

// Returns the sum of the pweights that UnwindPath(unique_path, unique_depth, path_index)
// would produce, without modifying unique_path. The per-element terms are the same
// formulas UnwindPath uses to recompute each pweight.
double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;
  double total = 0;
  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      const double tmp = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      total += tmp;
      next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth - i)
                                                                     / static_cast<double>(unique_depth + 1));
    } else {
      // NOTE(review): assumes zero_fraction != 0 whenever one_fraction == 0
      total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
                                                           / static_cast<double>(unique_depth + 1));
    }
  }
  return total;
}

// recursive computation of SHAP values for a decision tree
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {

  // extend the unique path
657
  PathElement* unique_path = parent_unique_path + unique_depth;
Guolin Ke's avatar
Guolin Ke committed
658
659
660
661
662
663
664
665
666
667
668
669
670
671
  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

    // internal node
  } else {
672
    const int hot_index = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
673
674
675
676
677
678
679
680
681
682
683
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
684
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
Guolin Ke's avatar
Guolin Ke committed
685
686
687
688
689
690
691
692
693
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
694
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
695
696

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
697
             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
698
699
700
  }
}

701
702
703
704
705
double Tree::ExpectedValue() const {
  if (num_leaves_ == 1) return LeafOutput(0);
  const double total_count = internal_count_[0];
  double exp_value = 0.0;
  for (int i = 0; i < num_leaves(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
706
    exp_value += (leaf_count_[i] / total_count)*LeafOutput(i);
Guolin Ke's avatar
Guolin Ke committed
707
  }
708
  return exp_value;
Guolin Ke's avatar
Guolin Ke committed
709
710
}

711
712
713
714
715
716
717
718
719
720
721
void Tree::RecomputeMaxDepth() {
  if (num_leaves_ == 1) {
    max_depth_ = 0;
  } else {
    if (leaf_depth_.size() == 0) {
      RecomputeLeafDepths(0, 0);
    }
    max_depth_ = leaf_depth_[0];
    for (int i = 1; i < num_leaves(); ++i) {
      if (max_depth_ < leaf_depth_[i]) max_depth_ = leaf_depth_[i];
    }
Guolin Ke's avatar
Guolin Ke committed
722
723
724
  }
}

Guolin Ke's avatar
Guolin Ke committed
725
}  // namespace LightGBM