tree.cpp 25.3 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <LightGBM/tree.h>

#include <LightGBM/utils/threading.h>
#include <LightGBM/utils/common.h>

#include <LightGBM/dataset.h>

#include <sstream>
#include <unordered_map>
#include <functional>
#include <vector>
#include <string>
Guolin Ke's avatar
Guolin Ke committed
13
#include <memory>
14
#include <iomanip>
Guolin Ke's avatar
Guolin Ke committed
15
16
17
18
19
20

namespace LightGBM {

Tree::Tree(int max_leaves)
  :max_leaves_(max_leaves) {
  // A tree with max_leaves_ leaves has max_leaves_ - 1 internal (split) nodes.
  const int num_internal_nodes = max_leaves_ - 1;

  // Per-internal-node storage.
  left_child_.resize(num_internal_nodes);
  right_child_.resize(num_internal_nodes);
  split_feature_inner_.resize(num_internal_nodes);
  split_feature_.resize(num_internal_nodes);
  threshold_in_bin_.resize(num_internal_nodes);
  threshold_.resize(num_internal_nodes);
  decision_type_.resize(num_internal_nodes, 0);
  split_gain_.resize(num_internal_nodes);
  internal_value_.resize(num_internal_nodes);
  internal_count_.resize(num_internal_nodes);

  // Per-leaf storage.
  leaf_parent_.resize(max_leaves_);
  leaf_value_.resize(max_leaves_);
  leaf_count_.resize(max_leaves_);
  leaf_depth_.resize(max_leaves_);

  // Start as a single root leaf: depth 0, value 0, no parent.
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_value_[0] = 0.0f;
  leaf_parent_[0] = -1;
  shrinkage_ = 1.0f;

  // No categorical splits yet; the boundary tables start with a leading 0 offset.
  num_cat_ = 0;
  cat_boundaries_.push_back(0);
  cat_boundaries_inner_.push_back(0);
}
Guolin Ke's avatar
Guolin Ke committed
45

Guolin Ke's avatar
Guolin Ke committed
46
// All members release their own storage; nothing extra to tear down.
Tree::~Tree() = default;

50
51
52
53
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
                double threshold_double, double left_value, double right_value,
                data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type, bool default_left) {
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
Guolin Ke's avatar
Guolin Ke committed
54
  int new_node_idx = num_leaves_ - 1;
Guolin Ke's avatar
Guolin Ke committed
55
  decision_type_[new_node_idx] = 0;
56
  SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
57
58
59
60
61
62
63
  SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
64
  }
Guolin Ke's avatar
Guolin Ke committed
65
  threshold_in_bin_[new_node_idx] = threshold_bin;
66
  threshold_[new_node_idx] = Common::AvoidInf(threshold_double);
67
68
69
  ++num_leaves_;
  return num_leaves_ - 1;
}
Guolin Ke's avatar
Guolin Ke committed
70

71
72
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                           const uint32_t* threshold, int num_threshold, double left_value, double right_value,
73
74
75
76
77
78
79
80
81
82
83
84
                           data_size_t left_cnt, data_size_t right_cnt, double gain, MissingType missing_type) {
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
  int new_node_idx = num_leaves_ - 1;
  decision_type_[new_node_idx] = 0;
  SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
  }
85
86
  threshold_in_bin_[new_node_idx] = num_cat_;
  threshold_[new_node_idx] = num_cat_;
87
  ++num_cat_;
88
89
90
91
92
93
94
95
  cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
  for (int i = 0; i < num_threshold; ++i) {
    cat_threshold_.push_back(threshold[i]);
  }
  cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
  for (int i = 0; i < num_threshold_bin; ++i) {
    cat_threshold_inner_.push_back(threshold_bin[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
  ++num_leaves_;
  return num_leaves_ - 1;
}

100
101
102
103
104
105
106
107
108
109
110
111
112
113
// PredictionFun expands to the body of one Threading::For chunk in AddPredictionToScore.
// It creates BinIterators, then walks the tree for every row in [start, end) and adds the
// reached leaf's value to score.
//   niter        : number of BinIterators to create.
//   fidx_in_iter : feature index for iterator `i`; may refer to the first loop's `i`.
//   start_pos    : position each iterator is Reset() to before the row loop.
//   decision_fun : member that maps (bin, node, default_bin, max_bin) to the next node;
//                  a negative result ends the walk and is the complemented leaf id (~leaf).
//   iter_idx     : which iterator to read at the current node; may refer to `node`.
//   data_idx     : row index passed to Get() and used to index score; may refer to the
//                  row loop's `i`.
// The expansion relies on `data`, `start`, `end`, `score`, `default_bins`, `max_bins` and
// `leaf_value_` being in scope at the call site. Arguments are substituted textually, so
// the names `i` and `node` inside them bind to the macro's own local variables.
#define PredictionFun(niter, fidx_in_iter, start_pos, decision_fun, iter_idx, data_idx) \
std::vector<std::unique_ptr<BinIterator>> iter((niter)); \
for (int i = 0; i < (niter); ++i) { \
  iter[i].reset(data->FeatureIterator((fidx_in_iter))); \
  iter[i]->Reset((start_pos)); \
}\
for (data_size_t i = start; i < end; ++i) {\
  int node = 0;\
  while (node >= 0) {\
    node = decision_fun(iter[(iter_idx)]->Get((data_idx)), node, default_bins[node], max_bins[node]);\
  }\
  score[(data_idx)] += static_cast<double>(leaf_value_[~node]);\
}\

114
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
119
120
121
122
123
  if (num_leaves_ <= 1) {
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[i] += leaf_value_[0];
      }
    }
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
124
125
126
127
128
129
130
131
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
132
  if (num_cat_ > 0) {
133
    if (data->num_features() > num_leaves_ - 1) {
134
135
136
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
137
138
      });
    } else {
139
140
141
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
142
143
      });
    }
Guolin Ke's avatar
Guolin Ke committed
144
  } else {
145
    if (data->num_features() > num_leaves_ - 1) {
146
147
148
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
149
150
      });
    } else {
151
152
153
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
154
155
      });
    }
Guolin Ke's avatar
Guolin Ke committed
156
  }
Guolin Ke's avatar
Guolin Ke committed
157
158
}

Guolin Ke's avatar
Guolin Ke committed
159
void Tree::AddPredictionToScore(const Dataset* data,
Guolin Ke's avatar
Guolin Ke committed
160
161
162
                                const data_size_t* used_data_indices,
                                data_size_t num_data, double* score) const {
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
163
164
165
166
167
168
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[used_data_indices[i]] += leaf_value_[0];
      }
    }
Guolin Ke's avatar
Guolin Ke committed
169
    return;
Guolin Ke's avatar
Guolin Ke committed
170
  }
Guolin Ke's avatar
Guolin Ke committed
171
172
173
174
175
176
177
178
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
179
  if (num_cat_ > 0) {
180
    if (data->num_features() > num_leaves_ - 1) {
181
182
183
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
184
185
      });
    } else {
186
187
188
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
189
190
      });
    }
Guolin Ke's avatar
Guolin Ke committed
191
  } else {
192
    if (data->num_features() > num_leaves_ - 1) {
193
194
195
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
196
197
      });
    } else {
198
199
200
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
201
202
      });
    }
Guolin Ke's avatar
Guolin Ke committed
203
  }
Guolin Ke's avatar
Guolin Ke committed
204
205
}

206
207
// PredictionFun is only meant for the two AddPredictionToScore overloads above; keep it
// from leaking into the rest of the translation unit.
#undef PredictionFun

Guolin Ke's avatar
Guolin Ke committed
208
std::string Tree::ToString() const {
209
210
  std::stringstream str_buf;
  str_buf << "num_leaves=" << num_leaves_ << std::endl;
211
  str_buf << "num_cat=" << num_cat_ << std::endl;
212
  str_buf << "split_feature="
Guolin Ke's avatar
Guolin Ke committed
213
    << Common::ArrayToString<int>(split_feature_, num_leaves_ - 1, ' ') << std::endl;
214
  str_buf << "split_gain="
Guolin Ke's avatar
Guolin Ke committed
215
    << Common::ArrayToString<double>(split_gain_, num_leaves_ - 1, ' ') << std::endl;
216
  str_buf << "threshold="
Guolin Ke's avatar
Guolin Ke committed
217
    << Common::ArrayToString<double>(threshold_, num_leaves_ - 1, ' ') << std::endl;
218
219
  str_buf << "decision_type="
    << Common::ArrayToString<int>(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1, ' ') << std::endl;
220
  str_buf << "left_child="
Guolin Ke's avatar
Guolin Ke committed
221
    << Common::ArrayToString<int>(left_child_, num_leaves_ - 1, ' ') << std::endl;
222
  str_buf << "right_child="
Guolin Ke's avatar
Guolin Ke committed
223
    << Common::ArrayToString<int>(right_child_, num_leaves_ - 1, ' ') << std::endl;
224
  str_buf << "leaf_value="
Guolin Ke's avatar
Guolin Ke committed
225
    << Common::ArrayToString<double>(leaf_value_, num_leaves_, ' ') << std::endl;
226
  str_buf << "leaf_count="
Guolin Ke's avatar
Guolin Ke committed
227
    << Common::ArrayToString<data_size_t>(leaf_count_, num_leaves_, ' ') << std::endl;
228
  str_buf << "internal_value="
Guolin Ke's avatar
Guolin Ke committed
229
    << Common::ArrayToString<double>(internal_value_, num_leaves_ - 1, ' ') << std::endl;
230
  str_buf << "internal_count="
Guolin Ke's avatar
Guolin Ke committed
231
    << Common::ArrayToString<data_size_t>(internal_count_, num_leaves_ - 1, ' ') << std::endl;
232
233
234
235
236
237
  if (num_cat_ > 0) {
    str_buf << "cat_boundaries="
      << Common::ArrayToString<int>(cat_boundaries_, num_cat_ + 1, ' ') << std::endl;
    str_buf << "cat_threshold="
      << Common::ArrayToString<uint32_t>(cat_threshold_, cat_threshold_.size(), ' ') << std::endl;
  }
Guolin Ke's avatar
Guolin Ke committed
238
  str_buf << "shrinkage=" << shrinkage_ << std::endl;
239
240
  str_buf << std::endl;
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
241
242
}

Guolin Ke's avatar
Guolin Ke committed
243
std::string Tree::ToJSON() const {
244
  std::stringstream str_buf;
245
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
246
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << std::endl;
247
  str_buf << "\"num_cat\":" << num_cat_ << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
248
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << std::endl;
wxchan's avatar
wxchan committed
249
  if (num_leaves_ == 1) {
Guolin Ke's avatar
Guolin Ke committed
250
    str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << std::endl;
wxchan's avatar
wxchan committed
251
252
253
  } else {
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << std::endl;
  }
wxchan's avatar
wxchan committed
254

255
  return str_buf.str();
wxchan's avatar
wxchan committed
256
257
}

Guolin Ke's avatar
Guolin Ke committed
258
std::string Tree::NodeToJSON(int index) const {
259
  std::stringstream str_buf;
260
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
261
262
  if (index >= 0) {
    // non-leaf
263
264
    str_buf << "{" << std::endl;
    str_buf << "\"split_index\":" << index << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
265
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << std::endl;
266
    str_buf << "\"split_gain\":" << split_gain_[index] << "," << std::endl;
267
    if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
268
269
270
271
272
273
274
275
276
277
278
279
      int cat_idx = static_cast<int>(threshold_[index]);
      std::vector<int> cats;
      for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
        for (int j = 0; j < 32; ++j) {
          int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
          if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
            cats.push_back(cat);
          }
        }
      }
      str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << std::endl;
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
      str_buf << "\"decision_type\":\"==\"," << std::endl;
    } else {
      str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << std::endl;
      str_buf << "\"decision_type\":\"<=\"," << std::endl;
    }
    if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
      str_buf << "\"default_left\":true," << std::endl;
    } else {
      str_buf << "\"default_left\":false," << std::endl;
    }
    uint8_t missing_type = GetMissingType(decision_type_[index]);
    if (missing_type == 0) {
      str_buf << "\"missing_type\":\"None\"," << std::endl;
    } else if (missing_type == 1) {
      str_buf << "\"missing_type\":\"Zero\"," << std::endl;
    } else {
      str_buf << "\"missing_type\":\"NaN\"," << std::endl;
    }
298
299
300
301
302
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << std::endl;
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << std::endl;
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << std::endl;
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << std::endl;
    str_buf << "}";
wxchan's avatar
wxchan committed
303
304
305
  } else {
    // leaf
    index = ~index;
306
307
308
309
310
    str_buf << "{" << std::endl;
    str_buf << "\"leaf_index\":" << index << "," << std::endl;
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << std::endl;
    str_buf << "\"leaf_count\":" << leaf_count_[index] << std::endl;
    str_buf << "}";
wxchan's avatar
wxchan committed
311
312
  }

313
  return str_buf.str();
wxchan's avatar
wxchan committed
314
315
}

Guolin Ke's avatar
Guolin Ke committed
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
// Emits the opening `if (...) {` of generated C++ code for numerical node `node`. The
// generated condition reads the local `fval`; the true branch is the left child.
std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
  // Print the threshold with round-trip precision (digits10 + 2, the setting NodeToIfElse
  // uses). With the stream's default of 6 significant digits, the generated comparison
  // would use a rounded split point. Values between the true and the rounded threshold
  // would then be routed to the wrong child.
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  // Missing-type codes: 0 = None, 1 = Zero, 2 = NaN.
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
  if (missing_type == 0 || (missing_type == 1 && default_left && kZeroAsMissingValueRange < threshold_[node])) {
    // Plain comparison. In the zero-as-missing/default-left case, zero presumably already
    // satisfies fval <= threshold here, so no extra zero test is emitted.
    // A NaN fval fails this comparison and takes the right branch.
    // NOTE(review): confirm that NaN routing matches NumericalDecision in tree.h.
    str_buf << "if (fval <= " << threshold_[node] << ") {";
  } else if (missing_type == 1) {
    // Zero (and NaN) inputs follow the default direction.
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    // NaN inputs follow the default direction.
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
  if (missing_type == 2) {
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
346
347
348
349
350
  int cat_idx = int(threshold_[node]);
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
Guolin Ke's avatar
Guolin Ke committed
351
352
353
354
  return str_buf.str();
}

std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) const {
355
356
357
358
359
360
  std::stringstream str_buf;
  str_buf << "double PredictTree" << index;
  if (is_predict_leaf_index) {
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
361
362
  if (num_leaves_ <= 1) {
    str_buf << "return " << leaf_value_[0] << ";";
363
  } else {
364
365
366
367
368
369
370
371
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
372
373
374
375
376
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
377
378
379
380
381
382
    str_buf << NodeToIfElse(0, is_predict_leaf_index);
  }
  str_buf << " }" << std::endl;
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
383
std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) const {
384
385
386
387
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
388
    str_buf << "fval = arr[" << split_feature_[index] << "];";
Guolin Ke's avatar
Guolin Ke committed
389
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
390
      str_buf << NumericalDecisionIfElse(index);
391
    } else {
392
      str_buf << CategoricalDecisionIfElse(index);
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
    }
    // left subtree
    str_buf << NodeToIfElse(left_child_[index], is_predict_leaf_index);
    str_buf << " } else { ";
    // right subtree
    str_buf << NodeToIfElse(right_child_[index], is_predict_leaf_index);
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
    if (is_predict_leaf_index) {
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
414
// Rebuilds a tree from the "key=value" text produced by ToString().
// Required keys: num_leaves, num_cat, leaf_value. When num_leaves > 1, the keys
// left_child, right_child, split_feature and threshold are also required, plus
// cat_boundaries and cat_threshold when num_cat > 0. Missing required keys trigger
// Log::Fatal. Optional keys fall back to zero-filled arrays, or to 1.0 for shrinkage.
// NOTE(review): leaf_parent_ and leaf_depth_ are not restored here; MaxDepth() recomputes
// leaf_depth_ when it is empty.
Tree::Tree(const std::string& str) {
  // Collect lines that split into exactly two '='-separated parts with a non-empty
  // trimmed key and value; anything else is ignored.
  std::vector<std::string> lines = Common::SplitLines(str.c_str());
  std::unordered_map<std::string, std::string> key_vals;
  for (const std::string& line : lines) {
    std::vector<std::string> tmp_strs = Common::Split(line.c_str(), '=');
    if (tmp_strs.size() == 2) {
      std::string key = Common::Trim(tmp_strs[0]);
      std::string val = Common::Trim(tmp_strs[1]);
      if (key.size() > 0 && val.size() > 0) {
        key_vals[key] = val;
      }
    }
  }

  if (key_vals.count("num_leaves") == 0) {
    Log::Fatal("Tree model should contain num_leaves field.");
  }
  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);

  if (key_vals.count("num_cat") == 0) {
    Log::Fatal("Tree model should contain num_cat field.");
  }
  Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);

  if (key_vals.count("leaf_value")) {
    leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], ' ', num_leaves_);
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

  // Per-leaf and whole-tree fields are read before the single-leaf early return. Otherwise
  // a one-leaf tree would keep an empty leaf_count_ (ToString reads num_leaves_ entries of
  // it) and this constructor would never assign shrinkage_.
  if (key_vals.count("leaf_count")) {
    leaf_count_ = Common::StringToArray<data_size_t>(key_vals["leaf_count"], ' ', num_leaves_);
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (key_vals.count("shrinkage")) {
    Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }

  // A single-leaf tree has no internal nodes to read.
  if (num_leaves_ <= 1) { return; }

  if (key_vals.count("left_child")) {
    left_child_ = Common::StringToArray<int>(key_vals["left_child"], ' ', num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
  }

  if (key_vals.count("right_child")) {
    right_child_ = Common::StringToArray<int>(key_vals["right_child"], ' ', num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (key_vals.count("split_feature")) {
    split_feature_ = Common::StringToArray<int>(key_vals["split_feature"], ' ', num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (key_vals.count("threshold")) {
    threshold_ = Common::StringToArray<double>(key_vals["threshold"], ' ', num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  if (key_vals.count("split_gain")) {
    split_gain_ = Common::StringToArray<double>(key_vals["split_gain"], ' ', num_leaves_ - 1);
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_count")) {
    internal_count_ = Common::StringToArray<data_size_t>(key_vals["internal_count"], ' ', num_leaves_ - 1);
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_value")) {
    internal_value_ = Common::StringToArray<double>(key_vals["internal_value"], ' ', num_leaves_ - 1);
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("decision_type")) {
    decision_type_ = Common::StringToArray<int8_t>(key_vals["decision_type"], ' ', num_leaves_ - 1);
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }

  if (num_cat_ > 0) {
    if (key_vals.count("cat_boundaries")) {
      cat_boundaries_ = Common::StringToArray<int>(key_vals["cat_boundaries"], ' ', num_cat_ + 1);
    } else {
      Log::Fatal("Tree model should contain cat_boundaries field.");
    }

    if (key_vals.count("cat_threshold")) {
      cat_threshold_ = Common::StringToArray<uint32_t>(key_vals["cat_threshold"], ' ', cat_boundaries_.back());
    } else {
      Log::Fatal("Tree model should contain cat_threshold field.");
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
// Tree SHAP helper: append one split to the current unique path.
// Writes (feature_index, zero_fraction, one_fraction) into slot unique_depth of unique_path
// and rescales the pweight values of slots 0..unique_depth. The caller must provide room
// for at least unique_depth + 1 elements.
// NOTE(review): this appears to follow the EXTEND step of Lundberg et al.'s Tree SHAP
// algorithm — confirm against the paper before changing anything. The statement order below
// fixes the exact floating-point results.
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  unique_path[unique_depth].feature_index = feature_index;
  unique_path[unique_depth].zero_fraction = zero_fraction;
  unique_path[unique_depth].one_fraction = one_fraction;
  // The new slot starts with weight 1 only when it is the first element of the path.
  unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
  // Walk from the previous last slot down to 0. Part of each weight moves one slot up
  // (scaled by one_fraction); the remainder stays in place (scaled by zero_fraction).
  for (int i = unique_depth - 1; i >= 0; i--) {
    unique_path[i + 1].pweight += one_fraction*unique_path[i].pweight*(i + 1)
      / static_cast<double>(unique_depth + 1);
    unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth - i)
      / static_cast<double>(unique_depth + 1);
  }
}

// Tree SHAP helper: undo ExtendPath for the element at path_index.
// Recomputes pweight for slots 0..unique_depth-1 as if that element had never been added.
// It then shifts feature_index/zero_fraction/one_fraction of the later slots down by one.
// Afterwards the caller treats the path as one element shorter (TreeSHAP decrements its
// depth right after calling this).
// NOTE(review): the one_fraction == 0 branch divides by zero_fraction; presumably
// zero_fraction is never 0 there — confirm.
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  // Carried weight, seeded from the last slot and propagated downwards.
  double next_one_portion = unique_path[unique_depth].pweight;

  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      const double tmp = unique_path[i].pweight;
      unique_path[i].pweight = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth - i)
        / static_cast<double>(unique_depth + 1);
    } else {
      unique_path[i].pweight = (unique_path[i].pweight*(unique_depth + 1))
        / static_cast<double>(zero_fraction*(unique_depth - i));
    }
  }

  // Close the gap left by the removed element (pweight was already recomputed above).
  for (int i = path_index; i < unique_depth; ++i) {
    unique_path[i].feature_index = unique_path[i + 1].feature_index;
    unique_path[i].zero_fraction = unique_path[i + 1].zero_fraction;
    unique_path[i].one_fraction = unique_path[i + 1].one_fraction;
  }
}

// Tree SHAP helper: returns the sum of the pweights that UnwindPath(unique_path,
// unique_depth, path_index) would produce, without modifying unique_path.
// It uses the same recurrence as UnwindPath, but with the divisions grouped slightly
// differently, so results may differ from an actual unwind by floating-point rounding.
double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;
  double total = 0;
  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      // tmp is the pweight slot i would receive after unwinding.
      const double tmp = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      total += tmp;
      next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth - i)
                                                                     / static_cast<double>(unique_depth + 1));
    } else {
      total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
                                                           / static_cast<double>(unique_depth + 1));
    }
  }
  return total;
}

// recursive computation of SHAP values for a decision tree
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {

  // extend the unique path
  PathElement *unique_path = parent_unique_path + unique_depth;
  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

    // internal node
  } else {
603
    const int hot_index = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
604
605
606
607
608
609
610
611
612
613
614
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
615
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
Guolin Ke's avatar
Guolin Ke committed
616
617
618
619
620
621
622
623
624
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
625
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
626
627

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
628
             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
629
630
631
  }
}

632
633
634
635
636
637
// Average tree output: each leaf's output weighted by leaf_count / internal_count_[0]
// (the root's count). A single-leaf tree simply returns its only output.
double Tree::ExpectedValue() const {
  if (num_leaves_ == 1) {
    return LeafOutput(0);
  }
  const double root_count = internal_count_[0];
  double expectation = 0.0;
  for (int leaf = 0; leaf < num_leaves(); ++leaf) {
    expectation += (leaf_count_[leaf]/root_count)*LeafOutput(leaf);
  }
  return expectation;
}

642
643
644
// Returns the largest leaf depth (0 for a single-leaf tree). Leaf depths are recomputed
// first when leaf_depth_ is empty, e.g. when it was never filled by this object.
int Tree::MaxDepth() {
  if (leaf_depth_.empty()) {
    RecomputeLeafDepths();
  }
  if (num_leaves_ == 1) {
    return 0;
  }
  int deepest = 0;
  for (int leaf = 0; leaf < num_leaves(); ++leaf) {
    deepest = (deepest < leaf_depth_[leaf]) ? leaf_depth_[leaf] : deepest;
  }
  return deepest;
}

Guolin Ke's avatar
Guolin Ke committed
652
}  // namespace LightGBM