tree.cpp 26.8 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <LightGBM/tree.h>

#include <LightGBM/utils/threading.h>
#include <LightGBM/utils/common.h>

#include <LightGBM/dataset.h>

#include <sstream>
#include <unordered_map>
#include <functional>
#include <vector>
#include <string>
#include <memory>
#include <iomanip>
#include <limits>
Guolin Ke's avatar
Guolin Ke committed
15
16
17
18
19
20

namespace LightGBM {

// Builds an empty tree consisting of a single root leaf, with storage
// pre-sized for up to `max_leaves` leaves (and `max_leaves - 1` internal nodes).
Tree::Tree(int max_leaves)
  :max_leaves_(max_leaves) {
  // per-internal-node arrays
  const auto internal_capacity = max_leaves_ - 1;
  left_child_.resize(internal_capacity);
  right_child_.resize(internal_capacity);
  split_feature_inner_.resize(internal_capacity);
  split_feature_.resize(internal_capacity);
  threshold_in_bin_.resize(internal_capacity);
  threshold_.resize(internal_capacity);
  decision_type_.resize(internal_capacity, 0);
  split_gain_.resize(internal_capacity);
  internal_value_.resize(internal_capacity);
  internal_count_.resize(internal_capacity);

  // per-leaf arrays
  leaf_parent_.resize(max_leaves_);
  leaf_value_.resize(max_leaves_);
  leaf_count_.resize(max_leaves_);
  leaf_depth_.resize(max_leaves_);

  // the tree starts as one leaf (index 0): depth 0, no parent, zero output
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_parent_[0] = -1;
  leaf_value_[0] = 0.0f;

  shrinkage_ = 1.0f;

  // no categorical splits yet; both boundary arrays start with the sentinel 0
  num_cat_ = 0;
  cat_boundaries_.push_back(0);
  cat_boundaries_inner_.push_back(0);
}
Guolin Ke's avatar
Guolin Ke committed
45

Guolin Ke's avatar
Guolin Ke committed
46
// Nothing to release explicitly: the destructor body is empty and all
// cleanup is left to the members' own destructors.
Tree::~Tree() = default;

50
51
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
                double threshold_double, double left_value, double right_value,
52
                int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left) {
53
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
Guolin Ke's avatar
Guolin Ke committed
54
  int new_node_idx = num_leaves_ - 1;
Guolin Ke's avatar
Guolin Ke committed
55
  decision_type_[new_node_idx] = 0;
56
  SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
57
58
59
60
61
62
63
  SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
64
  }
Guolin Ke's avatar
Guolin Ke committed
65
  threshold_in_bin_[new_node_idx] = threshold_bin;
66
  threshold_[new_node_idx] = Common::AvoidInf(threshold_double);
67
68
69
  ++num_leaves_;
  return num_leaves_ - 1;
}
Guolin Ke's avatar
Guolin Ke committed
70

71
72
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                           const uint32_t* threshold, int num_threshold, double left_value, double right_value,
73
                           data_size_t left_cnt, data_size_t right_cnt, float gain, MissingType missing_type) {
74
75
76
77
78
79
80
81
82
83
84
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
  int new_node_idx = num_leaves_ - 1;
  decision_type_[new_node_idx] = 0;
  SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
  }
85
86
  threshold_in_bin_[new_node_idx] = num_cat_;
  threshold_[new_node_idx] = num_cat_;
87
  ++num_cat_;
88
89
90
91
92
93
94
95
  cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
  for (int i = 0; i < num_threshold; ++i) {
    cat_threshold_.push_back(threshold[i]);
  }
  cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
  for (int i = 0; i < num_threshold_bin; ++i) {
    cat_threshold_inner_.push_back(threshold_bin[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
  ++num_leaves_;
  return num_leaves_ - 1;
}

100
101
102
103
104
105
106
107
108
109
110
111
112
113
// PredictionFun: shared loop body used by both AddPredictionToScore overloads.
//
// The macro expands inside a lambda and relies on these names being in scope
// at the expansion site: `data`, `start`, `end`, `score`, `default_bins`,
// `max_bins` and the member `leaf_value_`.
//
// Arguments (all are expressions pasted textually, so they may refer to the
// loop variable `i` and to `node`):
//   niter        - number of BinIterator objects to create
//   fidx_in_iter - feature index handed to data->FeatureIterator() for iter[i]
//   start_pos    - position each iterator is Reset() to
//   decision_fun - function called as decision_fun(bin, node, default_bin, max_bin);
//                  its result becomes the next node
//   iter_idx     - which iterator to read while standing on `node`
//   data_idx     - row index used for both iterator->Get() and score[]
//
// For every i in [start, end) the walk starts at node 0 and continues while
// node >= 0; the final negative value is decoded with ~node to index
// leaf_value_, which is added to score[data_idx].
#define PredictionFun(niter, fidx_in_iter, start_pos, decision_fun, iter_idx, data_idx) \
std::vector<std::unique_ptr<BinIterator>> iter((niter)); \
for (int i = 0; i < (niter); ++i) { \
  iter[i].reset(data->FeatureIterator((fidx_in_iter))); \
  iter[i]->Reset((start_pos)); \
}\
for (data_size_t i = start; i < end; ++i) {\
  int node = 0;\
  while (node >= 0) {\
    node = decision_fun(iter[(iter_idx)]->Get((data_idx)), node, default_bins[node], max_bins[node]);\
  }\
  score[(data_idx)] += static_cast<double>(leaf_value_[~node]);\
}\

114
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
119
120
121
122
123
  if (num_leaves_ <= 1) {
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[i] += leaf_value_[0];
      }
    }
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
124
125
126
127
128
129
130
131
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
132
  if (num_cat_ > 0) {
133
    if (data->num_features() > num_leaves_ - 1) {
134
135
136
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
137
138
      });
    } else {
139
140
141
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
142
143
      });
    }
Guolin Ke's avatar
Guolin Ke committed
144
  } else {
145
    if (data->num_features() > num_leaves_ - 1) {
146
147
148
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
149
150
      });
    } else {
151
152
153
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
154
155
      });
    }
Guolin Ke's avatar
Guolin Ke committed
156
  }
Guolin Ke's avatar
Guolin Ke committed
157
158
}

Guolin Ke's avatar
Guolin Ke committed
159
void Tree::AddPredictionToScore(const Dataset* data,
Guolin Ke's avatar
Guolin Ke committed
160
161
162
                                const data_size_t* used_data_indices,
                                data_size_t num_data, double* score) const {
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
163
164
165
166
167
168
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[used_data_indices[i]] += leaf_value_[0];
      }
    }
Guolin Ke's avatar
Guolin Ke committed
169
    return;
Guolin Ke's avatar
Guolin Ke committed
170
  }
Guolin Ke's avatar
Guolin Ke committed
171
172
173
174
175
176
177
178
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
179
  if (num_cat_ > 0) {
180
    if (data->num_features() > num_leaves_ - 1) {
181
182
183
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
184
185
      });
    } else {
186
187
188
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
189
190
      });
    }
Guolin Ke's avatar
Guolin Ke committed
191
  } else {
192
    if (data->num_features() > num_leaves_ - 1) {
193
194
195
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
196
197
      });
    } else {
198
199
200
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
201
202
      });
    }
Guolin Ke's avatar
Guolin Ke committed
203
  }
Guolin Ke's avatar
Guolin Ke committed
204
205
}

206
207
#undef PredictionFun

Guolin Ke's avatar
Guolin Ke committed
208
std::string Tree::ToString() const {
209
  std::stringstream str_buf;
210
211
  str_buf << "num_leaves=" << num_leaves_ << '\n';
  str_buf << "num_cat=" << num_cat_ << '\n';
212
  str_buf << "split_feature="
213
    << Common::ArrayToStringFast(split_feature_, num_leaves_ - 1) << '\n';
214
  str_buf << "split_gain="
215
    << Common::ArrayToStringFast(split_gain_, num_leaves_ - 1) << '\n';
216
  str_buf << "threshold="
217
    << Common::ArrayToString(threshold_, num_leaves_ - 1) << '\n';
218
  str_buf << "decision_type="
219
    << Common::ArrayToStringFast(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1) << '\n';
220
  str_buf << "left_child="
221
    << Common::ArrayToStringFast(left_child_, num_leaves_ - 1) << '\n';
222
  str_buf << "right_child="
223
    << Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
224
  str_buf << "leaf_value="
225
    << Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
226
  str_buf << "leaf_count="
227
    << Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
228
  str_buf << "internal_value="
229
    << Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
230
  str_buf << "internal_count="
231
    << Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
232
233
  if (num_cat_ > 0) {
    str_buf << "cat_boundaries="
234
      << Common::ArrayToStringFast(cat_boundaries_, num_cat_ + 1) << '\n';
235
    str_buf << "cat_threshold="
236
      << Common::ArrayToStringFast(cat_threshold_, cat_threshold_.size()) << '\n';
237
  }
238
239
  str_buf << "shrinkage=" << shrinkage_ << '\n';
  str_buf << '\n';
240
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
241
242
}

Guolin Ke's avatar
Guolin Ke committed
243
std::string Tree::ToJSON() const {
244
  std::stringstream str_buf;
245
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
246
247
248
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
  str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
wxchan's avatar
wxchan committed
249
  if (num_leaves_ == 1) {
250
    str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
wxchan's avatar
wxchan committed
251
  } else {
252
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
wxchan's avatar
wxchan committed
253
  }
wxchan's avatar
wxchan committed
254

255
  return str_buf.str();
wxchan's avatar
wxchan committed
256
257
}

Guolin Ke's avatar
Guolin Ke committed
258
std::string Tree::NodeToJSON(int index) const {
259
  std::stringstream str_buf;
260
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
261
262
  if (index >= 0) {
    // non-leaf
263
264
265
266
    str_buf << "{" << '\n';
    str_buf << "\"split_index\":" << index << "," << '\n';
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
    str_buf << "\"split_gain\":" << split_gain_[index] << "," << '\n';
267
    if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
268
269
270
271
272
273
274
275
276
277
278
      int cat_idx = static_cast<int>(threshold_[index]);
      std::vector<int> cats;
      for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
        for (int j = 0; j < 32; ++j) {
          int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
          if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
            cats.push_back(cat);
          }
        }
      }
279
280
      str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << '\n';
      str_buf << "\"decision_type\":\"==\"," << '\n';
281
    } else {
282
283
      str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << '\n';
      str_buf << "\"decision_type\":\"<=\"," << '\n';
284
285
    }
    if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
286
      str_buf << "\"default_left\":true," << '\n';
287
    } else {
288
      str_buf << "\"default_left\":false," << '\n';
289
290
291
    }
    uint8_t missing_type = GetMissingType(decision_type_[index]);
    if (missing_type == 0) {
292
      str_buf << "\"missing_type\":\"None\"," << '\n';
293
    } else if (missing_type == 1) {
294
      str_buf << "\"missing_type\":\"Zero\"," << '\n';
295
    } else {
296
      str_buf << "\"missing_type\":\"NaN\"," << '\n';
297
    }
298
299
300
301
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
302
    str_buf << "}";
wxchan's avatar
wxchan committed
303
304
305
  } else {
    // leaf
    index = ~index;
306
307
308
309
    str_buf << "{" << '\n';
    str_buf << "\"leaf_index\":" << index << "," << '\n';
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
    str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
310
    str_buf << "}";
wxchan's avatar
wxchan committed
311
312
  }

313
  return str_buf.str();
wxchan's avatar
wxchan committed
314
315
}

Guolin Ke's avatar
Guolin Ke committed
316
317
318
319
std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
Guolin Ke's avatar
Guolin Ke committed
320
  if (missing_type == 0 || (missing_type == 1 && default_left && kZeroThreshold < threshold_[node])) {
Guolin Ke's avatar
Guolin Ke committed
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
    str_buf << "if (fval <= " << threshold_[node] << ") {";
  } else if (missing_type == 1) {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
  if (missing_type == 2) {
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
346
347
348
349
350
  int cat_idx = int(threshold_[node]);
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
Guolin Ke's avatar
Guolin Ke committed
351
352
353
354
  return str_buf.str();
}

std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) const {
355
356
357
358
359
360
  std::stringstream str_buf;
  str_buf << "double PredictTree" << index;
  if (is_predict_leaf_index) {
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
361
362
  if (num_leaves_ <= 1) {
    str_buf << "return " << leaf_value_[0] << ";";
363
  } else {
364
365
366
367
368
369
370
371
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
372
373
374
375
376
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
377
378
    str_buf << NodeToIfElse(0, is_predict_leaf_index);
  }
379
  str_buf << " }" << '\n';
380
381
382
383

  //Predict func by Map to ifelse
  str_buf << "double PredictTree" << index;
  if (is_predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
384
385
386
    str_buf << "LeafByMap";
  } else {
    str_buf << "ByMap";
387
388
389
  }
  str_buf << "(const std::unordered_map<int, double>& arr) { ";
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
    str_buf << "return " << leaf_value_[0] << ";";
  } else {
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
    str_buf << NodeToIfElseByMap(0, is_predict_leaf_index);
406
  }
407
  str_buf << " }" << '\n';
408

409
410
411
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
412
std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) const {
413
414
415
416
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
417
    str_buf << "fval = arr[" << split_feature_[index] << "];";
Guolin Ke's avatar
Guolin Ke committed
418
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
419
      str_buf << NumericalDecisionIfElse(index);
420
    } else {
421
      str_buf << CategoricalDecisionIfElse(index);
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
    }
    // left subtree
    str_buf << NodeToIfElse(left_child_[index], is_predict_leaf_index);
    str_buf << " } else { ";
    // right subtree
    str_buf << NodeToIfElse(right_child_[index], is_predict_leaf_index);
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
    if (is_predict_leaf_index) {
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
// Recursively generates the if/else code for the subtree rooted at `index`,
// reading features from a map named `arr` in the generated code; a feature
// absent from the map is read as 0.0f.
// `index` >= 0 is an internal node; a negative value is a leaf with index ~index.
std::string Tree::NodeToIfElseByMap(int index, bool is_predict_leaf_index) const {
  std::stringstream code;
  code << std::setprecision(std::numeric_limits<double>::digits10 + 2);

  if (index < 0) {
    // leaf: return either its index or its output value
    const int leaf = ~index;
    code << "return ";
    if (is_predict_leaf_index) {
      code << leaf;
    } else {
      code << leaf_value_[leaf];
    }
    code << ";";
    return code.str();
  }

  // non-leaf: load the feature (defaulting to zero), then open the branch condition
  const auto feature = split_feature_[index];
  code << "fval = arr.count(" << feature << ") > 0 ? arr.at(" << feature << ") : 0.0f;";
  const bool is_categorical = GetDecisionType(decision_type_[index], kCategoricalMask) != 0;
  if (is_categorical) {
    code << CategoricalDecisionIfElse(index);
  } else {
    code << NumericalDecisionIfElse(index);
  }
  // left subtree
  code << NodeToIfElseByMap(left_child_[index], is_predict_leaf_index);
  code << " } else { ";
  // right subtree
  code << NodeToIfElseByMap(right_child_[index], is_predict_leaf_index);
  code << " }";

  return code.str();
}

474
475
// Reconstructs a tree from the "key=value" text block written by ToString().
//
// str      - start of the block; parsing stops at an empty line, at the end
//            of the string, or after 15 lines, whichever comes first.
//            Assumed to be NUL-terminated -- TODO confirm with callers.
// used_len - receives the number of characters consumed from `str`.
//
// Mandatory keys: num_leaves, num_cat, leaf_value and, for trees with more
// than one leaf, left_child, right_child, split_feature, threshold (plus
// cat_boundaries / cat_threshold when num_cat > 0). A missing mandatory key
// is reported through Log::Fatal. Optional keys fall back to zero-filled
// arrays, or 1.0 for shrinkage.
//
// Changes relative to the previous implementation:
//   * the scanning loops stop at the terminating '\0', so truncated or
//     malformed input no longer reads past the end of the string;
//   * shrinkage is parsed before the single-leaf early return, so
//     shrinkage_ is always initialized by this constructor.
Tree::Tree(const char* str, size_t* used_len) {
  const char* p = str;
  std::unordered_map<std::string, std::string> key_vals;
  const int max_num_line = 15;
  int read_line = 0;
  while (read_line < max_num_line) {
    // an empty line (or the end of the input) terminates the block
    if (*p == '\0' || *p == '\r' || *p == '\n') break;
    const char* start = p;
    while (*p != '\0' && *p != '=') ++p;
    // ran out of input before finding '=': nothing more to parse
    if (*p == '\0') break;
    std::string key(start, p - start);
    ++p;
    start = p;
    while (*p != '\0' && *p != '\r' && *p != '\n') ++p;
    key_vals[key] = std::string(start, p - start);
    ++read_line;
    // accept "\n" as well as "\r\n" line endings
    if (*p == '\r') ++p;
    if (*p == '\n') ++p;
  }
  *used_len = p - str;

  auto has_key = [&key_vals](const char* key) {
    return key_vals.count(key) > 0;
  };

  if (!has_key("num_leaves")) {
    Log::Fatal("Tree model should contain num_leaves field.");
  }
  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);

  if (!has_key("num_cat")) {
    Log::Fatal("Tree model should contain num_cat field.");
  }
  Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);

  if (has_key("leaf_value")) {
    leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], num_leaves_);
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

  // parsed before the early return below so it is set for single-leaf trees too
  if (has_key("shrinkage")) {
    Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }

  // a single-leaf tree has no internal-node data
  if (num_leaves_ <= 1) { return; }

  if (has_key("left_child")) {
    left_child_ = Common::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
  }

  if (has_key("right_child")) {
    right_child_ = Common::StringToArrayFast<int>(key_vals["right_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (has_key("split_feature")) {
    split_feature_ = Common::StringToArrayFast<int>(key_vals["split_feature"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (has_key("threshold")) {
    threshold_ = Common::StringToArray<double>(key_vals["threshold"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  // optional statistics: default to zero-filled arrays when absent
  if (has_key("split_gain")) {
    split_gain_ = Common::StringToArrayFast<float>(key_vals["split_gain"], num_leaves_ - 1);
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (has_key("internal_count")) {
    internal_count_ = Common::StringToArrayFast<int>(key_vals["internal_count"], num_leaves_ - 1);
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (has_key("internal_value")) {
    internal_value_ = Common::StringToArrayFast<double>(key_vals["internal_value"], num_leaves_ - 1);
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (has_key("leaf_count")) {
    leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (has_key("decision_type")) {
    decision_type_ = Common::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }

  if (num_cat_ > 0) {
    if (has_key("cat_boundaries")) {
      cat_boundaries_ = Common::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
    } else {
      Log::Fatal("Tree model should contain cat_boundaries field.");
    }

    // the last boundary is the total number of category words
    if (has_key("cat_threshold")) {
      cat_threshold_ = Common::StringToArrayFast<uint32_t>(key_vals["cat_threshold"], cat_boundaries_.back());
    } else {
      Log::Fatal("Tree model should contain cat_threshold field.");
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
// Helper used by TreeSHAP: writes a new element at unique_path[unique_depth]
// and updates the pweight of every element in [0, unique_depth].
//
// unique_path   - path array; entries [0, unique_depth] are read and written
// unique_depth  - slot that receives the new element
// zero_fraction - stored in the new element and used to scale existing weights
// one_fraction  - stored in the new element and used to scale existing weights
// feature_index - stored in the new element
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  unique_path[unique_depth].feature_index = feature_index;
  unique_path[unique_depth].zero_fraction = zero_fraction;
  unique_path[unique_depth].one_fraction = one_fraction;
  // the very first element starts with weight 1, every later one with 0
  unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
  // iterate downwards: slot i + 1 is accumulated into before slot i is
  // overwritten, so each update uses the not-yet-modified pweight of slot i
  for (int i = unique_depth - 1; i >= 0; i--) {
    unique_path[i + 1].pweight += one_fraction*unique_path[i].pweight*(i + 1)
      / static_cast<double>(unique_depth + 1);
    unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth - i)
      / static_cast<double>(unique_depth + 1);
  }
}

// Helper used by TreeSHAP: recomputes the pweight entries in
// [0, unique_depth) using the fractions stored at unique_path[path_index],
// then removes that element by shifting the following elements down one slot.
// NOTE(review): this looks like the inverse of the weight update performed by
// ExtendPath for the element at path_index -- confirm against the algorithm
// description.
//
// unique_path  - path array; entries [0, unique_depth] are read
// unique_depth - index of the last valid element
// path_index   - element whose contribution is removed
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;

  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      // keep the old weight: it is needed to derive the next carry value
      const double tmp = unique_path[i].pweight;
      unique_path[i].pweight = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth - i)
        / static_cast<double>(unique_depth + 1);
    } else {
      // one_fraction == 0: only the zero_fraction scaling is divided out
      unique_path[i].pweight = (unique_path[i].pweight*(unique_depth + 1))
        / static_cast<double>(zero_fraction*(unique_depth - i));
    }
  }

  // close the gap left by the removed element; pweight is not shifted
  for (int i = path_index; i < unique_depth; ++i) {
    unique_path[i].feature_index = unique_path[i + 1].feature_index;
    unique_path[i].zero_fraction = unique_path[i + 1].zero_fraction;
    unique_path[i].one_fraction = unique_path[i + 1].one_fraction;
  }
}

// Helper used by TreeSHAP: returns a sum over the entries in
// [0, unique_depth) computed from their pweight values and the fractions
// stored at unique_path[path_index]. The path is taken by pointer-to-const
// and is not modified.
// NOTE(review): presumably this equals the total pweight the path would hold
// after UnwindPath(unique_path, unique_depth, path_index) -- confirm.
//
// unique_path  - path array; entries [0, unique_depth] are read
// unique_depth - index of the last valid element
// path_index   - element whose fractions drive the computation
double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;
  double total = 0;
  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      const double tmp = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      total += tmp;
      // carry value for the next (lower) index
      next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth - i)
                                                                     / static_cast<double>(unique_depth + 1));
    } else {
      // one_fraction == 0: only the zero_fraction scaling is divided out
      total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
                                                           / static_cast<double>(unique_depth + 1));
    }
  }
  return total;
}

// recursive computation of SHAP values for a decision tree
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {

  // extend the unique path
  PathElement *unique_path = parent_unique_path + unique_depth;
  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

    // internal node
  } else {
670
    const int hot_index = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
671
672
673
674
675
676
677
678
679
680
681
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
682
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
Guolin Ke's avatar
Guolin Ke committed
683
684
685
686
687
688
689
690
691
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
692
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
693
694

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
695
             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
696
697
698
  }
}

699
700
701
702
703
double Tree::ExpectedValue() const {
  if (num_leaves_ == 1) return LeafOutput(0);
  const double total_count = internal_count_[0];
  double exp_value = 0.0;
  for (int i = 0; i < num_leaves(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
704
    exp_value += (leaf_count_[i] / total_count)*LeafOutput(i);
Guolin Ke's avatar
Guolin Ke committed
705
  }
706
  return exp_value;
Guolin Ke's avatar
Guolin Ke committed
707
708
}

709
710
711
// Returns the largest value in leaf_depth_ over the tree's leaves
// (0 for a single-leaf tree). If no depths are stored, they are rebuilt
// first via RecomputeLeafDepths().
int Tree::MaxDepth() {
  if (leaf_depth_.empty()) {
    RecomputeLeafDepths();
  }
  if (num_leaves_ == 1) {
    return 0;
  }
  int deepest = 0;
  // leaf_depth_ may hold more slots than leaves, so bound by num_leaves()
  for (int leaf = 0; leaf < num_leaves(); ++leaf) {
    const auto depth = leaf_depth_[leaf];
    if (depth > deepest) {
      deepest = depth;
    }
  }
  return deepest;
}

Guolin Ke's avatar
Guolin Ke committed
719
}  // namespace LightGBM