"python-package/vscode:/vscode.git/clone" did not exist on "0f0dd9d5e6ac48516aad00e59d26e35181cf8430"
tree.cpp 28.3 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
#include <LightGBM/tree.h>

#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>

#include <functional>
#include <iomanip>
#include <limits>
#include <sstream>
Guolin Ke's avatar
Guolin Ke committed
14
15
16
17
18

namespace LightGBM {

/*!
 * \brief Build an empty tree with room for up to max_leaves leaves.
 *
 * Per-split arrays are sized max_leaves - 1 and per-leaf arrays are sized
 * max_leaves. The tree starts as a single leaf (index 0) with value 0,
 * weight 0, depth 0 and no parent (-1).
 * \param max_leaves Capacity, in leaves, of the tree being built
 */
Tree::Tree(int max_leaves)
  :max_leaves_(max_leaves) {
  const int split_capacity = max_leaves_ - 1;

  // storage indexed by split (internal node)
  left_child_.resize(split_capacity);
  right_child_.resize(split_capacity);
  split_feature_inner_.resize(split_capacity);
  split_feature_.resize(split_capacity);
  threshold_in_bin_.resize(split_capacity);
  threshold_.resize(split_capacity);
  decision_type_.resize(split_capacity, 0);
  split_gain_.resize(split_capacity);

  // storage indexed by leaf
  leaf_parent_.resize(max_leaves_);
  leaf_value_.resize(max_leaves_);
  leaf_weight_.resize(max_leaves_);
  leaf_count_.resize(max_leaves_);

  // aggregated statistics kept per split
  internal_value_.resize(split_capacity);
  internal_weight_.resize(split_capacity);
  internal_count_.resize(split_capacity);

  leaf_depth_.resize(max_leaves_);

  // the tree initially consists of the root leaf only, which sits at depth 0
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_value_[0] = 0.0f;
  leaf_weight_[0] = 0.0f;
  leaf_parent_[0] = -1;

  shrinkage_ = 1.0f;

  // no categorical splits yet; both boundary lists start with a leading 0
  num_cat_ = 0;
  cat_boundaries_.push_back(0);
  cat_boundaries_inner_.push_back(0);

  // -1: not computed yet (RecomputeMaxDepth assigns the real value)
  max_depth_ = -1;
}
Guolin Ke's avatar
Guolin Ke committed
47

Guolin Ke's avatar
Guolin Ke committed
48
49
50
/*! \brief Destructor; there is nothing to release by hand here. */
Tree::~Tree() = default;

51
52
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
                double threshold_double, double left_value, double right_value,
53
54
                int left_cnt, int right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type, bool default_left) {
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
Guolin Ke's avatar
Guolin Ke committed
55
  int new_node_idx = num_leaves_ - 1;
Guolin Ke's avatar
Guolin Ke committed
56
  decision_type_[new_node_idx] = 0;
57
  SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
62
63
64
  SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
65
  }
Guolin Ke's avatar
Guolin Ke committed
66
  threshold_in_bin_[new_node_idx] = threshold_bin;
Guolin Ke's avatar
Guolin Ke committed
67
  threshold_[new_node_idx] = threshold_double;
68
69
70
  ++num_leaves_;
  return num_leaves_ - 1;
}
Guolin Ke's avatar
Guolin Ke committed
71

72
73
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                           const uint32_t* threshold, int num_threshold, double left_value, double right_value,
74
75
                           data_size_t left_cnt, data_size_t right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type) {
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
76
77
78
79
80
81
82
83
84
85
  int new_node_idx = num_leaves_ - 1;
  decision_type_[new_node_idx] = 0;
  SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
  }
86
87
  threshold_in_bin_[new_node_idx] = num_cat_;
  threshold_[new_node_idx] = num_cat_;
88
  ++num_cat_;
89
90
91
92
93
94
95
96
  cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
  for (int i = 0; i < num_threshold; ++i) {
    cat_threshold_.push_back(threshold[i]);
  }
  cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
  for (int i = 0; i < num_threshold_bin; ++i) {
    cat_threshold_inner_.push_back(threshold_bin[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
97
98
99
100
  ++num_leaves_;
  return num_leaves_ - 1;
}

101
102
103
104
105
106
107
108
109
110
111
112
113
114
// PredictionFun: shared body of the per-range prediction lambdas below.
//
// The expansion relies on these names being visible at the point of use:
//   data, start, end, score, default_bins, max_bins and the member leaf_value_.
// It declares a local vector named `iter`, a loop variable `i` and, inside the
// row loop, a variable `node`; the macro arguments may refer to `i` and `node`.
//
//   niter        - number of bin iterators to create
//   fidx_in_iter - feature index used to create iterator i (may use `i`)
//   start_pos    - position every iterator is Reset() to
//   decision_fun - callable returning the next node from
//                  (bin value, node, default bin of node, max bin of node)
//   iter_idx     - which iterator to read for the current node (may use `node`)
//   data_idx     - row used for both the iterator lookup and the score slot
//
// For every i in [start, end) the tree is walked from node 0 until `node`
// becomes negative; leaf_value_[~node] is then added to score[data_idx].
#define PredictionFun(niter, fidx_in_iter, start_pos, decision_fun, iter_idx, data_idx) \
std::vector<std::unique_ptr<BinIterator>> iter((niter)); \
for (int i = 0; i < (niter); ++i) { \
  iter[i].reset(data->FeatureIterator((fidx_in_iter))); \
  iter[i]->Reset((start_pos)); \
}\
for (data_size_t i = start; i < end; ++i) {\
  int node = 0;\
  while (node >= 0) {\
    node = decision_fun(iter[(iter_idx)]->Get((data_idx)), node, default_bins[node], max_bins[node]);\
  }\
  score[(data_idx)] += static_cast<double>(leaf_value_[~node]);\
}\

115
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
116
117
118
119
120
121
122
123
124
  if (num_leaves_ <= 1) {
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[i] += leaf_value_[0];
      }
    }
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
125
126
127
128
129
130
131
132
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
133
  if (num_cat_ > 0) {
134
    if (data->num_features() > num_leaves_ - 1) {
135
136
137
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
138
139
      });
    } else {
140
141
142
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
143
144
      });
    }
Guolin Ke's avatar
Guolin Ke committed
145
  } else {
146
    if (data->num_features() > num_leaves_ - 1) {
147
148
149
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
150
151
      });
    } else {
152
153
154
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
155
156
      });
    }
Guolin Ke's avatar
Guolin Ke committed
157
  }
Guolin Ke's avatar
Guolin Ke committed
158
159
}

Guolin Ke's avatar
Guolin Ke committed
160
void Tree::AddPredictionToScore(const Dataset* data,
Guolin Ke's avatar
Guolin Ke committed
161
162
163
                                const data_size_t* used_data_indices,
                                data_size_t num_data, double* score) const {
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
164
165
166
167
168
169
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[used_data_indices[i]] += leaf_value_[0];
      }
    }
Guolin Ke's avatar
Guolin Ke committed
170
    return;
Guolin Ke's avatar
Guolin Ke committed
171
  }
Guolin Ke's avatar
Guolin Ke committed
172
173
174
175
176
177
178
179
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
180
  if (num_cat_ > 0) {
181
    if (data->num_features() > num_leaves_ - 1) {
182
183
184
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
185
186
      });
    } else {
187
188
189
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
190
191
      });
    }
Guolin Ke's avatar
Guolin Ke committed
192
  } else {
193
    if (data->num_features() > num_leaves_ - 1) {
194
195
196
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
197
198
      });
    } else {
199
200
201
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
202
203
      });
    }
Guolin Ke's avatar
Guolin Ke committed
204
  }
Guolin Ke's avatar
Guolin Ke committed
205
206
}

207
208
#undef PredictionFun

209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
/*! \brief Largest value stored in leaf_value_ over the first num_leaves_ leaves. */
double Tree::GetUpperBoundValue() const {
  double largest = leaf_value_[0];
  for (int leaf = 1; leaf < num_leaves_; ++leaf) {
    const double candidate = leaf_value_[leaf];
    largest = (candidate > largest) ? candidate : largest;
  }
  return largest;
}

/*! \brief Smallest value stored in leaf_value_ over the first num_leaves_ leaves. */
double Tree::GetLowerBoundValue() const {
  double smallest = leaf_value_[0];
  for (int leaf = 1; leaf < num_leaves_; ++leaf) {
    const double candidate = leaf_value_[leaf];
    smallest = (candidate < smallest) ? candidate : smallest;
  }
  return smallest;
}

Guolin Ke's avatar
Guolin Ke committed
229
std::string Tree::ToString() const {
230
  std::stringstream str_buf;
231
232
  str_buf << "num_leaves=" << num_leaves_ << '\n';
  str_buf << "num_cat=" << num_cat_ << '\n';
233
  str_buf << "split_feature="
234
    << Common::ArrayToStringFast(split_feature_, num_leaves_ - 1) << '\n';
235
  str_buf << "split_gain="
236
    << Common::ArrayToStringFast(split_gain_, num_leaves_ - 1) << '\n';
237
  str_buf << "threshold="
238
    << Common::ArrayToString(threshold_, num_leaves_ - 1) << '\n';
239
  str_buf << "decision_type="
240
    << Common::ArrayToStringFast(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1) << '\n';
241
  str_buf << "left_child="
242
    << Common::ArrayToStringFast(left_child_, num_leaves_ - 1) << '\n';
243
  str_buf << "right_child="
244
    << Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
245
  str_buf << "leaf_value="
246
    << Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
247
248
  str_buf << "leaf_weight="
    << Common::ArrayToString(leaf_weight_, num_leaves_) << '\n';
249
  str_buf << "leaf_count="
250
    << Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
251
  str_buf << "internal_value="
252
    << Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
253
254
  str_buf << "internal_weight="
    << Common::ArrayToStringFast(internal_weight_, num_leaves_ - 1) << '\n';
255
  str_buf << "internal_count="
256
    << Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
257
258
  if (num_cat_ > 0) {
    str_buf << "cat_boundaries="
259
      << Common::ArrayToStringFast(cat_boundaries_, num_cat_ + 1) << '\n';
260
    str_buf << "cat_threshold="
261
      << Common::ArrayToStringFast(cat_threshold_, cat_threshold_.size()) << '\n';
262
  }
263
264
  str_buf << "shrinkage=" << shrinkage_ << '\n';
  str_buf << '\n';
265
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
266
267
}

Guolin Ke's avatar
Guolin Ke committed
268
std::string Tree::ToJSON() const {
269
  std::stringstream str_buf;
270
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
271
272
273
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
  str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
wxchan's avatar
wxchan committed
274
  if (num_leaves_ == 1) {
275
    str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
wxchan's avatar
wxchan committed
276
  } else {
277
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
wxchan's avatar
wxchan committed
278
  }
wxchan's avatar
wxchan committed
279

280
  return str_buf.str();
wxchan's avatar
wxchan committed
281
282
}

Guolin Ke's avatar
Guolin Ke committed
283
std::string Tree::NodeToJSON(int index) const {
284
  std::stringstream str_buf;
285
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
286
287
  if (index >= 0) {
    // non-leaf
288
289
290
    str_buf << "{" << '\n';
    str_buf << "\"split_index\":" << index << "," << '\n';
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
Guolin Ke's avatar
Guolin Ke committed
291
    str_buf << "\"split_gain\":" << Common::AvoidInf(split_gain_[index]) << "," << '\n';
292
    if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
293
294
295
296
297
298
299
300
301
302
303
      int cat_idx = static_cast<int>(threshold_[index]);
      std::vector<int> cats;
      for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
        for (int j = 0; j < 32; ++j) {
          int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
          if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
            cats.push_back(cat);
          }
        }
      }
304
305
      str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << '\n';
      str_buf << "\"decision_type\":\"==\"," << '\n';
306
    } else {
307
308
      str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << '\n';
      str_buf << "\"decision_type\":\"<=\"," << '\n';
309
310
    }
    if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
311
      str_buf << "\"default_left\":true," << '\n';
312
    } else {
313
      str_buf << "\"default_left\":false," << '\n';
314
315
316
    }
    uint8_t missing_type = GetMissingType(decision_type_[index]);
    if (missing_type == 0) {
317
      str_buf << "\"missing_type\":\"None\"," << '\n';
318
    } else if (missing_type == 1) {
319
      str_buf << "\"missing_type\":\"Zero\"," << '\n';
320
    } else {
321
      str_buf << "\"missing_type\":\"NaN\"," << '\n';
322
    }
323
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
324
    str_buf << "\"internal_weight\":" << internal_weight_[index] << "," << '\n';
325
326
327
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
328
    str_buf << "}";
wxchan's avatar
wxchan committed
329
330
331
  } else {
    // leaf
    index = ~index;
332
333
334
    str_buf << "{" << '\n';
    str_buf << "\"leaf_index\":" << index << "," << '\n';
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
335
    str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n';
336
    str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
337
    str_buf << "}";
wxchan's avatar
wxchan committed
338
339
  }

340
  return str_buf.str();
wxchan's avatar
wxchan committed
341
342
}

Guolin Ke's avatar
Guolin Ke committed
343
344
345
346
std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
Guolin Ke's avatar
Guolin Ke committed
347
  if (missing_type == 0 || (missing_type == 1 && default_left && kZeroThreshold < threshold_[node])) {
Guolin Ke's avatar
Guolin Ke committed
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
    str_buf << "if (fval <= " << threshold_[node] << ") {";
  } else if (missing_type == 1) {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
  if (missing_type == 2) {
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
373
  int cat_idx = static_cast<int>(threshold_[node]);
374
375
376
377
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
Guolin Ke's avatar
Guolin Ke committed
378
379
380
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
381
std::string Tree::ToIfElse(int index, bool predict_leaf_index) const {
382
383
  std::stringstream str_buf;
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
384
  if (predict_leaf_index) {
385
386
387
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
388
389
  if (num_leaves_ <= 1) {
    str_buf << "return " << leaf_value_[0] << ";";
390
  } else {
391
392
393
394
395
396
397
398
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
399
400
401
402
403
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
404
    str_buf << NodeToIfElse(0, predict_leaf_index);
405
  }
406
  str_buf << " }" << '\n';
407

408
  // Predict func by Map to ifelse
409
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
410
  if (predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
411
412
413
    str_buf << "LeafByMap";
  } else {
    str_buf << "ByMap";
414
415
416
  }
  str_buf << "(const std::unordered_map<int, double>& arr) { ";
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
    str_buf << "return " << leaf_value_[0] << ";";
  } else {
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
432
    str_buf << NodeToIfElseByMap(0, predict_leaf_index);
433
  }
434
  str_buf << " }" << '\n';
435

436
437
438
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
439
std::string Tree::NodeToIfElse(int index, bool predict_leaf_index) const {
440
441
442
443
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
444
    str_buf << "fval = arr[" << split_feature_[index] << "];";
Guolin Ke's avatar
Guolin Ke committed
445
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
446
      str_buf << NumericalDecisionIfElse(index);
447
    } else {
448
      str_buf << CategoricalDecisionIfElse(index);
449
450
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
451
    str_buf << NodeToIfElse(left_child_[index], predict_leaf_index);
452
453
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
454
    str_buf << NodeToIfElse(right_child_[index], predict_leaf_index);
455
456
457
458
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
459
    if (predict_leaf_index) {
460
461
462
463
464
465
466
467
468
469
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
470
std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
471
472
473
474
475
476
477
478
479
480
481
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
    str_buf << "fval = arr.count(" << split_feature_[index] << ") > 0 ? arr.at(" << split_feature_[index] << ") : 0.0f;";
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
      str_buf << NumericalDecisionIfElse(index);
    } else {
      str_buf << CategoricalDecisionIfElse(index);
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
482
    str_buf << NodeToIfElseByMap(left_child_[index], predict_leaf_index);
483
484
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
485
    str_buf << NodeToIfElseByMap(right_child_[index], predict_leaf_index);
486
487
488
489
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
490
    if (predict_leaf_index) {
491
492
493
494
495
496
497
498
499
500
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

501
502
/*!
 * \brief Build a tree from its text form (the "key=value" lines written by
 *        ToString).
 *
 * Reads at most 17 lines and stops at the first empty line or at the end of
 * the string. num_leaves, num_cat and leaf_value are mandatory; for trees
 * with more than one leaf, left_child, right_child, split_feature and
 * threshold are mandatory too, and cat_boundaries / cat_threshold are
 * mandatory when num_cat > 0. The remaining fields fall back to
 * zero-initialised arrays (shrinkage falls back to 1).
 *
 * The scanner treats '\0' as a terminator so that a truncated or malformed
 * string is reported through Log::Fatal instead of being read past its end.
 * \param str Null-terminated text to parse
 * \param used_len Receives the number of characters consumed from str
 */
Tree::Tree(const char* str, size_t* used_len) {
  // set up front so that the single-leaf early return below does not leave it
  // unassigned; -1 means "not computed yet"
  max_depth_ = -1;

  auto p = str;
  std::unordered_map<std::string, std::string> key_vals;
  const int max_num_line = 17;
  int read_line = 0;
  while (read_line < max_num_line) {
    // an empty line or the end of the buffer ends the tree section
    if (*p == '\0' || *p == '\r' || *p == '\n') break;
    auto start = p;
    // the key runs up to '='; never scan beyond the current line or the buffer
    while (*p != '=' && *p != '\0' && *p != '\r' && *p != '\n') ++p;
    if (*p != '=') {
      Log::Fatal("Tree model string format error, every line should be in key=value format");
    }
    std::string key(start, p - start);
    ++p;
    start = p;
    // the value runs up to the end of the line (or of the buffer)
    while (*p != '\0' && *p != '\r' && *p != '\n') ++p;
    key_vals[key] = std::string(start, p - start);
    ++read_line;
    // swallow "\r\n", "\n" or a lone "\r"
    if (*p == '\r') ++p;
    if (*p == '\n') ++p;
  }
  *used_len = p - str;

  if (key_vals.count("num_leaves") <= 0) {
    Log::Fatal("Tree model should contain num_leaves field");
  }

  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);

  if (key_vals.count("num_cat") <= 0) {
    Log::Fatal("Tree model should contain num_cat field");
  }

  Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);

  if (key_vals.count("leaf_value")) {
    leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], num_leaves_);
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

  if (key_vals.count("shrinkage")) {
    Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }

  // a single-leaf tree has no split structure to load
  if (num_leaves_ <= 1) { return; }

  if (key_vals.count("left_child")) {
    left_child_ = Common::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
  }

  if (key_vals.count("right_child")) {
    right_child_ = Common::StringToArrayFast<int>(key_vals["right_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (key_vals.count("split_feature")) {
    split_feature_ = Common::StringToArrayFast<int>(key_vals["split_feature"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (key_vals.count("threshold")) {
    threshold_ = Common::StringToArray<double>(key_vals["threshold"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  // optional statistics: absent fields become zero-filled arrays
  if (key_vals.count("split_gain")) {
    split_gain_ = Common::StringToArrayFast<float>(key_vals["split_gain"], num_leaves_ - 1);
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_count")) {
    internal_count_ = Common::StringToArrayFast<int>(key_vals["internal_count"], num_leaves_ - 1);
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_value")) {
    internal_value_ = Common::StringToArrayFast<double>(key_vals["internal_value"], num_leaves_ - 1);
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_weight")) {
    internal_weight_ = Common::StringToArrayFast<double>(key_vals["internal_weight"], num_leaves_ - 1);
  } else {
    internal_weight_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("leaf_weight")) {
    leaf_weight_ = Common::StringToArrayFast<double>(key_vals["leaf_weight"], num_leaves_);
  } else {
    leaf_weight_.resize(num_leaves_);
  }

  if (key_vals.count("leaf_count")) {
    leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (key_vals.count("decision_type")) {
    decision_type_ = Common::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }

  if (num_cat_ > 0) {
    if (key_vals.count("cat_boundaries")) {
      cat_boundaries_ = Common::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
    } else {
      Log::Fatal("Tree model should contain cat_boundaries field.");
    }

    if (key_vals.count("cat_threshold")) {
      // the last boundary is the total number of bitset words
      cat_threshold_ = Common::StringToArrayFast<uint32_t>(key_vals["cat_threshold"], cat_boundaries_.back());
    } else {
      Log::Fatal("Tree model should contain cat_threshold field");
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
/*!
 * \brief Append one element to a SHAP "unique path" and update the weights of
 *        the elements already on it.
 *
 * The new element is written at position unique_depth, so unique_path must
 * have room for unique_depth + 1 elements.
 * NOTE(review): this looks like the path-extension step of the TreeSHAP
 * algorithm; confirm against the declaration in tree.h.
 * \param unique_path Path storage; elements [0, unique_depth) are already set
 * \param unique_depth Position of the element to add
 * \param zero_fraction Value stored as the new element's zero_fraction
 * \param one_fraction Value stored as the new element's one_fraction
 * \param feature_index Value stored as the new element's feature_index
 */
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  unique_path[unique_depth].feature_index = feature_index;
  unique_path[unique_depth].zero_fraction = zero_fraction;
  unique_path[unique_depth].one_fraction = one_fraction;
  // the very first element starts with weight 1, every later one with 0
  unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
  // walk backwards so that pweight[i] is still the old value when it is
  // pushed into pweight[i + 1] and only afterwards rescaled in place
  for (int i = unique_depth - 1; i >= 0; i--) {
    unique_path[i + 1].pweight += one_fraction*unique_path[i].pweight*(i + 1)
      / static_cast<double>(unique_depth + 1);
    unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth - i)
      / static_cast<double>(unique_depth + 1);
  }
}

/*!
 * \brief Remove the element at path_index from a SHAP "unique path".
 *
 * First recomputes the weights of elements [0, unique_depth) from the
 * fractions of the removed element, then shifts feature_index, zero_fraction
 * and one_fraction of the elements after path_index down by one position
 * (pweight values are not shifted).
 * NOTE(review): presumably the inverse of ExtendPath for that element;
 * confirm against the declaration in tree.h.
 * \param unique_path Path storage holding elements [0, unique_depth]
 * \param unique_depth Index of the last element currently on the path
 * \param path_index Index of the element to remove
 */
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;

  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      // keep the old weight: it is needed to derive the next carried portion
      const double tmp = unique_path[i].pweight;
      unique_path[i].pweight = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth - i)
        / static_cast<double>(unique_depth + 1);
    } else {
      // one_fraction is 0 here, so the weight is rescaled through zero_fraction
      unique_path[i].pweight = (unique_path[i].pweight*(unique_depth + 1))
        / static_cast<double>(zero_fraction*(unique_depth - i));
    }
  }

  // close the gap left by the removed element
  for (int i = path_index; i < unique_depth; ++i) {
    unique_path[i].feature_index = unique_path[i + 1].feature_index;
    unique_path[i].zero_fraction = unique_path[i + 1].zero_fraction;
    unique_path[i].one_fraction = unique_path[i + 1].one_fraction;
  }
}

/*!
 * \brief Sum of the weights the path would carry with the element at
 *        path_index removed, computed without modifying the path.
 *
 * Uses the same recurrence as UnwindPath but accumulates the recomputed
 * weights into a total instead of writing them back.
 * \param unique_path Path holding elements [0, unique_depth]
 * \param unique_depth Index of the last element on the path
 * \param path_index Index of the element treated as removed
 * \return The accumulated weight total
 */
double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;
  double total = 0;
  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      // tmp is the weight element i would get after unwinding
      const double tmp = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      total += tmp;
      next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth - i)
                                                                     / static_cast<double>(unique_depth + 1));
    } else {
      // one_fraction is 0 here, so the weight is rescaled through zero_fraction
      total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
                                                           / static_cast<double>(unique_depth + 1));
    }
  }
  return total;
}

// recursive computation of SHAP values for a decision tree
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
694
  PathElement* unique_path = parent_unique_path + unique_depth;
Guolin Ke's avatar
Guolin Ke committed
695
696
697
698
699
700
701
702
703
704
705
706
707
708
  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

    // internal node
  } else {
709
    const int hot_index = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
710
711
712
713
714
715
716
717
718
719
720
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
721
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
Guolin Ke's avatar
Guolin Ke committed
722
723
724
725
726
727
728
729
730
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
731
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
732
733

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
734
             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
735
736
737
  }
}

738
739
740
741
742
double Tree::ExpectedValue() const {
  if (num_leaves_ == 1) return LeafOutput(0);
  const double total_count = internal_count_[0];
  double exp_value = 0.0;
  for (int i = 0; i < num_leaves(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
743
    exp_value += (leaf_count_[i] / total_count)*LeafOutput(i);
Guolin Ke's avatar
Guolin Ke committed
744
  }
745
  return exp_value;
Guolin Ke's avatar
Guolin Ke committed
746
747
}

748
749
750
751
752
753
754
755
756
757
758
/*!
 * \brief Set max_depth_ to the largest leaf depth (0 for a single-leaf tree).
 *
 * When leaf_depth_ is empty it is first filled through RecomputeLeafDepths.
 */
void Tree::RecomputeMaxDepth() {
  if (num_leaves_ == 1) {
    max_depth_ = 0;
    return;
  }
  if (leaf_depth_.empty()) {
    RecomputeLeafDepths(0, 0);
  }
  max_depth_ = leaf_depth_[0];
  for (int leaf = 1; leaf < num_leaves(); ++leaf) {
    if (leaf_depth_[leaf] > max_depth_) {
      max_depth_ = leaf_depth_[leaf];
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
762
}  // namespace LightGBM