tree.cpp 26.7 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
#include <LightGBM/tree.h>

#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>

#include <functional>
#include <iomanip>
#include <limits>
#include <sstream>
Guolin Ke's avatar
Guolin Ke committed
10
11
12
13
14

namespace LightGBM {

Tree::Tree(int max_leaves)
  :max_leaves_(max_leaves) {
  // Per-split arrays get max_leaves_ - 1 slots, per-leaf arrays get max_leaves_ slots.
  const int internal_capacity = max_leaves_ - 1;
  const int leaf_capacity = max_leaves_;

  // storage indexed by internal (split) node
  left_child_.resize(internal_capacity);
  right_child_.resize(internal_capacity);
  split_feature_inner_.resize(internal_capacity);
  split_feature_.resize(internal_capacity);
  threshold_in_bin_.resize(internal_capacity);
  threshold_.resize(internal_capacity);
  decision_type_.resize(internal_capacity, 0);
  split_gain_.resize(internal_capacity);
  internal_value_.resize(internal_capacity);
  internal_count_.resize(internal_capacity);

  // storage indexed by leaf
  leaf_parent_.resize(leaf_capacity);
  leaf_value_.resize(leaf_capacity);
  leaf_count_.resize(leaf_capacity);
  leaf_depth_.resize(leaf_capacity);

  // The tree starts as a single leaf (index 0): depth 0, zero output, no parent.
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_value_[0] = 0.0f;
  leaf_parent_[0] = -1;

  // No categorical splits yet; both boundary arrays start with the sentinel 0.
  num_cat_ = 0;
  cat_boundaries_.push_back(0);
  cat_boundaries_inner_.push_back(0);

  shrinkage_ = 1.0f;
  // -1 is the initial value; RecomputeMaxDepth() is what assigns a real depth.
  max_depth_ = -1;
}
Guolin Ke's avatar
Guolin Ke committed
40

Guolin Ke's avatar
Guolin Ke committed
41
42
43
// Nothing to release by hand: the destructor body is empty.
Tree::~Tree() = default;

44
45
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
                double threshold_double, double left_value, double right_value,
46
                int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left) {
47
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
Guolin Ke's avatar
Guolin Ke committed
48
  int new_node_idx = num_leaves_ - 1;
Guolin Ke's avatar
Guolin Ke committed
49
  decision_type_[new_node_idx] = 0;
50
  SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
51
52
53
54
55
56
57
  SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
58
  }
Guolin Ke's avatar
Guolin Ke committed
59
  threshold_in_bin_[new_node_idx] = threshold_bin;
60
  threshold_[new_node_idx] = Common::AvoidInf(threshold_double);
61
62
63
  ++num_leaves_;
  return num_leaves_ - 1;
}
Guolin Ke's avatar
Guolin Ke committed
64

65
66
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                           const uint32_t* threshold, int num_threshold, double left_value, double right_value,
67
                           data_size_t left_cnt, data_size_t right_cnt, float gain, MissingType missing_type) {
68
69
70
71
72
73
74
75
76
77
78
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
  int new_node_idx = num_leaves_ - 1;
  decision_type_[new_node_idx] = 0;
  SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
  }
79
80
  threshold_in_bin_[new_node_idx] = num_cat_;
  threshold_[new_node_idx] = num_cat_;
81
  ++num_cat_;
82
83
84
85
86
87
88
89
  cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
  for (int i = 0; i < num_threshold; ++i) {
    cat_threshold_.push_back(threshold[i]);
  }
  cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
  for (int i = 0; i < num_threshold_bin; ++i) {
    cat_threshold_inner_.push_back(threshold_bin[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
90
91
92
93
  ++num_leaves_;
  return num_leaves_ - 1;
}

94
95
96
97
98
99
100
101
102
103
104
105
106
107
// Shared loop body for the AddPredictionToScore overloads below.
//
// The expansion relies on these names being visible at the expansion site:
//   data, start, end, score, default_bins, max_bins (and the member leaf_value_).
// It declares a local vector `iter` and uses loop variables named `i` and a per-row
// variable named `node`; the macro ARGUMENTS are allowed to (and do) refer to `i`
// and `node`, so those names must not be changed.
//
// Parameters:
//   niter        - number of BinIterator objects to create
//   fidx_in_iter - feature index handed to data->FeatureIterator for iterator `i`
//   start_pos    - position passed to Reset() on every iterator
//   decision_fun - callable invoked as decision_fun(bin, node, default_bin, max_bin);
//                  traversal continues while its result is >= 0
//   iter_idx     - which iterator to read while standing on `node`
//   data_idx     - row index used both for iterator Get() and for `score`
//
// For each i in [start, end) the tree is walked from node 0 until a negative value is
// produced; leaf_value_[~node] is then added to score[data_idx].
// Do not put comments inside the macro body: each line ends in a continuation backslash.
#define PredictionFun(niter, fidx_in_iter, start_pos, decision_fun, iter_idx, data_idx) \
std::vector<std::unique_ptr<BinIterator>> iter((niter)); \
for (int i = 0; i < (niter); ++i) { \
  iter[i].reset(data->FeatureIterator((fidx_in_iter))); \
  iter[i]->Reset((start_pos)); \
}\
for (data_size_t i = start; i < end; ++i) {\
  int node = 0;\
  while (node >= 0) {\
    node = decision_fun(iter[(iter_idx)]->Get((data_idx)), node, default_bins[node], max_bins[node]);\
  }\
  score[(data_idx)] += static_cast<double>(leaf_value_[~node]);\
}\

108
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
114
115
116
117
  if (num_leaves_ <= 1) {
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[i] += leaf_value_[0];
      }
    }
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
118
119
120
121
122
123
124
125
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
126
  if (num_cat_ > 0) {
127
    if (data->num_features() > num_leaves_ - 1) {
128
129
130
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
131
132
      });
    } else {
133
134
135
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
136
137
      });
    }
Guolin Ke's avatar
Guolin Ke committed
138
  } else {
139
    if (data->num_features() > num_leaves_ - 1) {
140
141
142
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
143
144
      });
    } else {
145
146
147
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
148
149
      });
    }
Guolin Ke's avatar
Guolin Ke committed
150
  }
Guolin Ke's avatar
Guolin Ke committed
151
152
}

Guolin Ke's avatar
Guolin Ke committed
153
void Tree::AddPredictionToScore(const Dataset* data,
Guolin Ke's avatar
Guolin Ke committed
154
155
156
                                const data_size_t* used_data_indices,
                                data_size_t num_data, double* score) const {
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
157
158
159
160
161
162
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[used_data_indices[i]] += leaf_value_[0];
      }
    }
Guolin Ke's avatar
Guolin Ke committed
163
    return;
Guolin Ke's avatar
Guolin Ke committed
164
  }
Guolin Ke's avatar
Guolin Ke committed
165
166
167
168
169
170
171
172
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
173
  if (num_cat_ > 0) {
174
    if (data->num_features() > num_leaves_ - 1) {
175
176
177
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
178
179
      });
    } else {
180
181
182
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
183
184
      });
    }
Guolin Ke's avatar
Guolin Ke committed
185
  } else {
186
    if (data->num_features() > num_leaves_ - 1) {
187
188
189
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
190
191
      });
    } else {
192
193
194
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
195
196
      });
    }
Guolin Ke's avatar
Guolin Ke committed
197
  }
Guolin Ke's avatar
Guolin Ke committed
198
199
}

200
201
#undef PredictionFun

Guolin Ke's avatar
Guolin Ke committed
202
std::string Tree::ToString() const {
203
  std::stringstream str_buf;
204
205
  str_buf << "num_leaves=" << num_leaves_ << '\n';
  str_buf << "num_cat=" << num_cat_ << '\n';
206
  str_buf << "split_feature="
207
    << Common::ArrayToStringFast(split_feature_, num_leaves_ - 1) << '\n';
208
  str_buf << "split_gain="
209
    << Common::ArrayToStringFast(split_gain_, num_leaves_ - 1) << '\n';
210
  str_buf << "threshold="
211
    << Common::ArrayToString(threshold_, num_leaves_ - 1) << '\n';
212
  str_buf << "decision_type="
213
    << Common::ArrayToStringFast(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1) << '\n';
214
  str_buf << "left_child="
215
    << Common::ArrayToStringFast(left_child_, num_leaves_ - 1) << '\n';
216
  str_buf << "right_child="
217
    << Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
218
  str_buf << "leaf_value="
219
    << Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
220
  str_buf << "leaf_count="
221
    << Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
222
  str_buf << "internal_value="
223
    << Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
224
  str_buf << "internal_count="
225
    << Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
226
227
  if (num_cat_ > 0) {
    str_buf << "cat_boundaries="
228
      << Common::ArrayToStringFast(cat_boundaries_, num_cat_ + 1) << '\n';
229
    str_buf << "cat_threshold="
230
      << Common::ArrayToStringFast(cat_threshold_, cat_threshold_.size()) << '\n';
231
  }
232
233
  str_buf << "shrinkage=" << shrinkage_ << '\n';
  str_buf << '\n';
234
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
235
236
}

Guolin Ke's avatar
Guolin Ke committed
237
std::string Tree::ToJSON() const {
238
  std::stringstream str_buf;
239
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
240
241
242
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
  str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
wxchan's avatar
wxchan committed
243
  if (num_leaves_ == 1) {
244
    str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
wxchan's avatar
wxchan committed
245
  } else {
246
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
wxchan's avatar
wxchan committed
247
  }
wxchan's avatar
wxchan committed
248

249
  return str_buf.str();
wxchan's avatar
wxchan committed
250
251
}

Guolin Ke's avatar
Guolin Ke committed
252
std::string Tree::NodeToJSON(int index) const {
253
  std::stringstream str_buf;
254
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
255
256
  if (index >= 0) {
    // non-leaf
257
258
259
260
    str_buf << "{" << '\n';
    str_buf << "\"split_index\":" << index << "," << '\n';
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
    str_buf << "\"split_gain\":" << split_gain_[index] << "," << '\n';
261
    if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
262
263
264
265
266
267
268
269
270
271
272
      int cat_idx = static_cast<int>(threshold_[index]);
      std::vector<int> cats;
      for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
        for (int j = 0; j < 32; ++j) {
          int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
          if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
            cats.push_back(cat);
          }
        }
      }
273
274
      str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << '\n';
      str_buf << "\"decision_type\":\"==\"," << '\n';
275
    } else {
276
277
      str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << '\n';
      str_buf << "\"decision_type\":\"<=\"," << '\n';
278
279
    }
    if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
280
      str_buf << "\"default_left\":true," << '\n';
281
    } else {
282
      str_buf << "\"default_left\":false," << '\n';
283
284
285
    }
    uint8_t missing_type = GetMissingType(decision_type_[index]);
    if (missing_type == 0) {
286
      str_buf << "\"missing_type\":\"None\"," << '\n';
287
    } else if (missing_type == 1) {
288
      str_buf << "\"missing_type\":\"Zero\"," << '\n';
289
    } else {
290
      str_buf << "\"missing_type\":\"NaN\"," << '\n';
291
    }
292
293
294
295
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
296
    str_buf << "}";
wxchan's avatar
wxchan committed
297
298
299
  } else {
    // leaf
    index = ~index;
300
301
302
303
    str_buf << "{" << '\n';
    str_buf << "\"leaf_index\":" << index << "," << '\n';
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
    str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
304
    str_buf << "}";
wxchan's avatar
wxchan committed
305
306
  }

307
  return str_buf.str();
wxchan's avatar
wxchan committed
308
309
}

Guolin Ke's avatar
Guolin Ke committed
310
311
312
313
std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
Guolin Ke's avatar
Guolin Ke committed
314
  if (missing_type == 0 || (missing_type == 1 && default_left && kZeroThreshold < threshold_[node])) {
Guolin Ke's avatar
Guolin Ke committed
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
    str_buf << "if (fval <= " << threshold_[node] << ") {";
  } else if (missing_type == 1) {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
  if (missing_type == 2) {
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
340
  int cat_idx = static_cast<int>(threshold_[node]);
341
342
343
344
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
Guolin Ke's avatar
Guolin Ke committed
345
346
347
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
348
std::string Tree::ToIfElse(int index, bool predict_leaf_index) const {
349
350
  std::stringstream str_buf;
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
351
  if (predict_leaf_index) {
352
353
354
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
355
356
  if (num_leaves_ <= 1) {
    str_buf << "return " << leaf_value_[0] << ";";
357
  } else {
358
359
360
361
362
363
364
365
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
366
367
368
369
370
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
371
    str_buf << NodeToIfElse(0, predict_leaf_index);
372
  }
373
  str_buf << " }" << '\n';
374

375
  // Predict func by Map to ifelse
376
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
377
  if (predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
378
379
380
    str_buf << "LeafByMap";
  } else {
    str_buf << "ByMap";
381
382
383
  }
  str_buf << "(const std::unordered_map<int, double>& arr) { ";
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
    str_buf << "return " << leaf_value_[0] << ";";
  } else {
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
399
    str_buf << NodeToIfElseByMap(0, predict_leaf_index);
400
  }
401
  str_buf << " }" << '\n';
402

403
404
405
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
406
std::string Tree::NodeToIfElse(int index, bool predict_leaf_index) const {
407
408
409
410
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
411
    str_buf << "fval = arr[" << split_feature_[index] << "];";
Guolin Ke's avatar
Guolin Ke committed
412
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
413
      str_buf << NumericalDecisionIfElse(index);
414
    } else {
415
      str_buf << CategoricalDecisionIfElse(index);
416
417
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
418
    str_buf << NodeToIfElse(left_child_[index], predict_leaf_index);
419
420
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
421
    str_buf << NodeToIfElse(right_child_[index], predict_leaf_index);
422
423
424
425
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
426
    if (predict_leaf_index) {
427
428
429
430
431
432
433
434
435
436
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
437
std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
438
439
440
441
442
443
444
445
446
447
448
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
    str_buf << "fval = arr.count(" << split_feature_[index] << ") > 0 ? arr.at(" << split_feature_[index] << ") : 0.0f;";
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
      str_buf << NumericalDecisionIfElse(index);
    } else {
      str_buf << CategoricalDecisionIfElse(index);
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
449
    str_buf << NodeToIfElseByMap(left_child_[index], predict_leaf_index);
450
451
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
452
    str_buf << NodeToIfElseByMap(right_child_[index], predict_leaf_index);
453
454
455
456
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
457
    if (predict_leaf_index) {
458
459
460
461
462
463
464
465
466
467
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

468
469
// Reconstructs a tree from the "key=value" text produced by ToString().
//
// str:      start of the tree section; assumed to be NUL-terminated (TODO confirm
//           against callers). Reading stops at the first empty line, after 15
//           key/value lines, or at the terminating NUL.
// used_len: receives the number of characters consumed from `str`.
//
// num_leaves, num_cat and leaf_value are mandatory; for trees with more than one leaf
// left_child, right_child, split_feature and threshold are mandatory as well, and
// cat_boundaries / cat_threshold are mandatory when num_cat > 0. A missing mandatory
// field is reported through Log::Fatal. Optional arrays default to zero-filled vectors
// and shrinkage defaults to 1.
Tree::Tree(const char* str, size_t* used_len) {
  const char* p = str;
  std::unordered_map<std::string, std::string> key_vals;
  const int max_num_line = 15;
  int read_line = 0;
  while (read_line < max_num_line) {
    // an empty line (or the end of the string) terminates the tree section
    if (*p == '\0' || *p == '\r' || *p == '\n') break;
    const char* start = p;
    // The NUL checks keep a truncated or malformed section from scanning past the end
    // of the string; well-formed input never hits them.
    while (*p != '\0' && *p != '=') ++p;
    if (*p == '\0') break;
    std::string key(start, p - start);
    ++p;
    start = p;
    while (*p != '\0' && *p != '\r' && *p != '\n') ++p;
    key_vals[key] = std::string(start, p - start);
    ++read_line;
    // accept "\n", "\r" and "\r\n" line endings
    if (*p == '\r') ++p;
    if (*p == '\n') ++p;
  }
  *used_len = p - str;

  if (key_vals.count("num_leaves") <= 0) {
    Log::Fatal("Tree model should contain num_leaves field");
  }

  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);

  if (key_vals.count("num_cat") <= 0) {
    Log::Fatal("Tree model should contain num_cat field");
  }

  Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);

  if (key_vals.count("leaf_value")) {
    leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], num_leaves_);
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

  // These scalars must be set before the single-leaf early return below; otherwise a
  // one-leaf tree would leave this constructor without assigning them.
  if (key_vals.count("shrinkage")) {
    Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }
  max_depth_ = -1;

  // a single-leaf tree has no split arrays to read
  if (num_leaves_ <= 1) { return; }

  if (key_vals.count("left_child")) {
    left_child_ = Common::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
  }

  if (key_vals.count("right_child")) {
    right_child_ = Common::StringToArrayFast<int>(key_vals["right_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (key_vals.count("split_feature")) {
    split_feature_ = Common::StringToArrayFast<int>(key_vals["split_feature"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (key_vals.count("threshold")) {
    threshold_ = Common::StringToArray<double>(key_vals["threshold"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  // The remaining arrays are optional and fall back to zero-filled vectors.
  if (key_vals.count("split_gain")) {
    split_gain_ = Common::StringToArrayFast<float>(key_vals["split_gain"], num_leaves_ - 1);
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_count")) {
    internal_count_ = Common::StringToArrayFast<int>(key_vals["internal_count"], num_leaves_ - 1);
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_value")) {
    internal_value_ = Common::StringToArrayFast<double>(key_vals["internal_value"], num_leaves_ - 1);
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("leaf_count")) {
    leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (key_vals.count("decision_type")) {
    decision_type_ = Common::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }

  if (num_cat_ > 0) {
    if (key_vals.count("cat_boundaries")) {
      cat_boundaries_ = Common::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
    } else {
      Log::Fatal("Tree model should contain cat_boundaries field.");
    }

    if (key_vals.count("cat_threshold")) {
      // the last boundary is the total number of bitset words
      cat_threshold_ = Common::StringToArrayFast<uint32_t>(key_vals["cat_threshold"], cat_boundaries_.back());
    } else {
      Log::Fatal("Tree model should contain cat_threshold field");
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
// Appends one element at position unique_depth of unique_path and updates the
// pweight of every element in [0, unique_depth] accordingly.
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  PathElement& tail = unique_path[unique_depth];
  tail.feature_index = feature_index;
  tail.zero_fraction = zero_fraction;
  tail.one_fraction = one_fraction;
  // only the very first element starts with weight 1
  tail.pweight = (unique_depth == 0 ? 1 : 0);

  const double path_len = static_cast<double>(unique_depth + 1);
  // Walk from the new tail towards the front; `upper` is always one slot after `lower`.
  for (int k = unique_depth; k > 0; --k) {
    PathElement& lower = unique_path[k - 1];
    PathElement& upper = unique_path[k];
    upper.pweight += one_fraction*lower.pweight*k / path_len;
    lower.pweight = zero_fraction*lower.pweight*(unique_depth - (k - 1)) / path_len;
  }
}

// Removes the element at path_index from unique_path: first the pweight values of
// elements [0, unique_depth) are recomputed, then the feature/fraction fields of the
// elements after path_index are shifted one slot towards the front.
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double carry = unique_path[unique_depth].pweight;

  // one_fraction is constant for the whole call, so the case split is done once.
  if (one_fraction != 0) {
    for (int k = unique_depth - 1; k >= 0; --k) {
      PathElement& slot = unique_path[k];
      const double previous = slot.pweight;
      slot.pweight = carry*(unique_depth + 1)
        / static_cast<double>((k + 1)*one_fraction);
      carry = previous - slot.pweight*zero_fraction*(unique_depth - k)
        / static_cast<double>(unique_depth + 1);
    }
  } else {
    for (int k = unique_depth - 1; k >= 0; --k) {
      PathElement& slot = unique_path[k];
      slot.pweight = (slot.pweight*(unique_depth + 1))
        / static_cast<double>(zero_fraction*(unique_depth - k));
    }
  }

  // pweight is deliberately not shifted: it was recomputed in place above.
  for (int k = path_index; k < unique_depth; ++k) {
    const PathElement& next = unique_path[k + 1];
    unique_path[k].feature_index = next.feature_index;
    unique_path[k].zero_fraction = next.zero_fraction;
    unique_path[k].one_fraction = next.one_fraction;
  }
}

double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;
  double total = 0;
  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      const double tmp = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      total += tmp;
      next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth - i)
                                                                     / static_cast<double>(unique_depth + 1));
    } else {
      total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
                                                           / static_cast<double>(unique_depth + 1));
    }
  }
  return total;
}

// recursive computation of SHAP values for a decision tree
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
649
  PathElement* unique_path = parent_unique_path + unique_depth;
Guolin Ke's avatar
Guolin Ke committed
650
651
652
653
654
655
656
657
658
659
660
661
662
663
  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

    // internal node
  } else {
664
    const int hot_index = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
665
666
667
668
669
670
671
672
673
674
675
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
676
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
Guolin Ke's avatar
Guolin Ke committed
677
678
679
680
681
682
683
684
685
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
686
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
687
688

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
689
             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
690
691
692
  }
}

693
694
695
696
697
double Tree::ExpectedValue() const {
  if (num_leaves_ == 1) return LeafOutput(0);
  const double total_count = internal_count_[0];
  double exp_value = 0.0;
  for (int i = 0; i < num_leaves(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
698
    exp_value += (leaf_count_[i] / total_count)*LeafOutput(i);
Guolin Ke's avatar
Guolin Ke committed
699
  }
700
  return exp_value;
Guolin Ke's avatar
Guolin Ke committed
701
702
}

703
704
705
706
707
708
709
710
711
712
713
// Sets max_depth_ to the largest entry of leaf_depth_ over all leaves (0 for a
// single-leaf tree). When leaf_depth_ is empty it is rebuilt first through
// RecomputeLeafDepths.
void Tree::RecomputeMaxDepth() {
  if (num_leaves_ == 1) {
    max_depth_ = 0;
    return;
  }
  if (leaf_depth_.empty()) {
    RecomputeLeafDepths(0, 0);
  }
  max_depth_ = leaf_depth_[0];
  for (int leaf = 1; leaf < num_leaves(); ++leaf) {
    if (leaf_depth_[leaf] > max_depth_) {
      max_depth_ = leaf_depth_[leaf];
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
717
}  // namespace LightGBM