tree.cpp 31.6 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
#include <LightGBM/tree.h>

#include <LightGBM/dataset.h>
8
9
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
10

11
12
13
14
#include <functional>
#include <iomanip>
#include <sstream>

Guolin Ke's avatar
Guolin Ke committed
15
16
namespace LightGBM {

17
18
/*!
 * \brief Build a tree that consists of a single root leaf and has storage
 *        pre-sized for up to max_leaves leaves.
 * \param max_leaves Number of slots given to every per-leaf array; the
 *        per-split arrays receive max_leaves - 1 slots.
 * \param track_branch_features When true, branch_features_ is allocated with
 *        one (initially empty) feature list per leaf slot.
 * \note Slot 0 of the per-leaf arrays is written unconditionally, so the code
 *       assumes max_leaves >= 1.
 */
Tree::Tree(int max_leaves, bool track_branch_features)
  :max_leaves_(max_leaves), track_branch_features_(track_branch_features) {
  const int leaf_capacity = max_leaves_;
  const int split_capacity = max_leaves_ - 1;

  // storage indexed by internal (split) node
  left_child_.resize(split_capacity);
  right_child_.resize(split_capacity);
  split_feature_inner_.resize(split_capacity);
  split_feature_.resize(split_capacity);
  threshold_in_bin_.resize(split_capacity);
  threshold_.resize(split_capacity);
  decision_type_.resize(split_capacity, 0);
  split_gain_.resize(split_capacity);
  internal_value_.resize(split_capacity);
  internal_weight_.resize(split_capacity);
  internal_count_.resize(split_capacity);

  // storage indexed by leaf
  leaf_parent_.resize(leaf_capacity);
  leaf_value_.resize(leaf_capacity);
  leaf_weight_.resize(leaf_capacity);
  leaf_count_.resize(leaf_capacity);
  leaf_depth_.resize(leaf_capacity);
  if (track_branch_features_) {
    branch_features_ = std::vector<std::vector<int>>(leaf_capacity);
  }

  // the tree starts as one leaf (the root) at depth 0 without a parent
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_parent_[0] = -1;
  leaf_value_[0] = 0.0f;
  leaf_weight_[0] = 0.0f;

  shrinkage_ = 1.0f;
  max_depth_ = -1;

  // no categorical split yet; both boundary lists start with the offset 0
  num_cat_ = 0;
  cat_boundaries_.push_back(0);
  cat_boundaries_inner_.push_back(0);
}
Guolin Ke's avatar
Guolin Ke committed
50

51
52
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
                double threshold_double, double left_value, double right_value,
53
54
                int left_cnt, int right_cnt, double left_weight, double right_weight, float gain,
                MissingType missing_type, bool default_left) {
55
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
Guolin Ke's avatar
Guolin Ke committed
56
  int new_node_idx = num_leaves_ - 1;
Guolin Ke's avatar
Guolin Ke committed
57
  decision_type_[new_node_idx] = 0;
58
  SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
59
  SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
Guolin Ke's avatar
Guolin Ke committed
60
  SetMissingType(&decision_type_[new_node_idx], static_cast<int8_t>(missing_type));
Guolin Ke's avatar
Guolin Ke committed
61
  threshold_in_bin_[new_node_idx] = threshold_bin;
Guolin Ke's avatar
Guolin Ke committed
62
  threshold_[new_node_idx] = threshold_double;
63
64
65
  ++num_leaves_;
  return num_leaves_ - 1;
}
Guolin Ke's avatar
Guolin Ke committed
66

67
68
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                           const uint32_t* threshold, int num_threshold, double left_value, double right_value,
69
70
                           data_size_t left_cnt, data_size_t right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type) {
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
71
72
73
  int new_node_idx = num_leaves_ - 1;
  decision_type_[new_node_idx] = 0;
  SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
74
  SetMissingType(&decision_type_[new_node_idx], static_cast<int8_t>(missing_type));
75
76
  threshold_in_bin_[new_node_idx] = num_cat_;
  threshold_[new_node_idx] = num_cat_;
77
  ++num_cat_;
78
79
80
81
82
83
84
85
  cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
  for (int i = 0; i < num_threshold; ++i) {
    cat_threshold_.push_back(threshold[i]);
  }
  cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
  for (int i = 0; i < num_threshold_bin; ++i) {
    cat_threshold_inner_.push_back(threshold_bin[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
86
87
88
89
  ++num_leaves_;
  return num_leaves_ - 1;
}

Guolin Ke's avatar
Guolin Ke committed
90
91
92
93
94
95
96
97
98
99
100
101
102
103
/*
 * PredictionFun: body shared by the Threading::For lambdas of the two
 * AddPredictionToScore overloads.
 *
 * It creates `niter` bin iterators (iterator i is built for feature
 * `fidx_in_iter` and reset to `start_pos`), then for every i in [start, end)
 * walks the tree from node 0 with `decision_fun` until a negative node value
 * is returned and adds leaf_value_[~node] to score[data_idx].
 *
 * The expansion relies on these names at the call site: data, start, end,
 * score, default_bins, max_bins. The arguments may refer to the macro's own
 * loop variable `i` and to `node`.
 */
#define PredictionFun(niter, fidx_in_iter, start_pos, decision_fun, iter_idx, \
                      data_idx)                                               \
  std::vector<std::unique_ptr<BinIterator>> bin_iters((niter));               \
  for (int i = 0; i < (niter); ++i) {                                         \
    bin_iters[i].reset(data->FeatureIterator((fidx_in_iter)));                \
    bin_iters[i]->Reset((start_pos));                                         \
  }                                                                           \
  for (data_size_t i = start; i < end; ++i) {                                 \
    int node = 0;                                                             \
    do {                                                                      \
      node = decision_fun(bin_iters[(iter_idx)]->Get((data_idx)), node,       \
                          default_bins[node], max_bins[node]);                \
    } while (node >= 0);                                                      \
    score[(data_idx)] += static_cast<double>(leaf_value_[~node]);             \
  }

106
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
107
108
  if (num_leaves_ <= 1) {
    if (leaf_value_[0] != 0.0f) {
Guolin Ke's avatar
Guolin Ke committed
109
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
Guolin Ke's avatar
Guolin Ke committed
110
111
112
113
114
115
      for (data_size_t i = 0; i < num_data; ++i) {
        score[i] += leaf_value_[0];
      }
    }
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
116
117
118
119
120
121
122
123
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
124
  if (num_cat_ > 0) {
125
    if (data->num_features() > num_leaves_ - 1) {
Guolin Ke's avatar
Guolin Ke committed
126
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
127
128
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
129
130
      });
    } else {
Guolin Ke's avatar
Guolin Ke committed
131
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
132
133
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
134
135
      });
    }
Guolin Ke's avatar
Guolin Ke committed
136
  } else {
137
    if (data->num_features() > num_leaves_ - 1) {
Guolin Ke's avatar
Guolin Ke committed
138
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
139
140
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
141
142
      });
    } else {
Guolin Ke's avatar
Guolin Ke committed
143
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
144
145
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
146
147
      });
    }
Guolin Ke's avatar
Guolin Ke committed
148
  }
Guolin Ke's avatar
Guolin Ke committed
149
150
}

Guolin Ke's avatar
Guolin Ke committed
151
void Tree::AddPredictionToScore(const Dataset* data,
Guolin Ke's avatar
Guolin Ke committed
152
153
154
                                const data_size_t* used_data_indices,
                                data_size_t num_data, double* score) const {
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
155
    if (leaf_value_[0] != 0.0f) {
Guolin Ke's avatar
Guolin Ke committed
156
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
Guolin Ke's avatar
Guolin Ke committed
157
158
159
160
      for (data_size_t i = 0; i < num_data; ++i) {
        score[used_data_indices[i]] += leaf_value_[0];
      }
    }
Guolin Ke's avatar
Guolin Ke committed
161
    return;
Guolin Ke's avatar
Guolin Ke committed
162
  }
Guolin Ke's avatar
Guolin Ke committed
163
164
165
166
167
168
169
170
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
171
  if (num_cat_ > 0) {
172
    if (data->num_features() > num_leaves_ - 1) {
Guolin Ke's avatar
Guolin Ke committed
173
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
174
175
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
176
177
      });
    } else {
Guolin Ke's avatar
Guolin Ke committed
178
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
179
180
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
181
182
      });
    }
Guolin Ke's avatar
Guolin Ke committed
183
  } else {
184
    if (data->num_features() > num_leaves_ - 1) {
Guolin Ke's avatar
Guolin Ke committed
185
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
186
187
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
188
189
      });
    } else {
Guolin Ke's avatar
Guolin Ke committed
190
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
191
192
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
193
194
      });
    }
Guolin Ke's avatar
Guolin Ke committed
195
  }
Guolin Ke's avatar
Guolin Ke committed
196
197
}

198
199
#undef PredictionFun

200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
/*!
 * \brief Largest value among the first num_leaves_ entries of leaf_value_.
 * \return The maximum leaf value; leaf 0 is always taken into account
 */
double Tree::GetUpperBoundValue() const {
  double largest = leaf_value_[0];
  for (int leaf = 1; leaf < num_leaves_; ++leaf) {
    const double candidate = leaf_value_[leaf];
    largest = (candidate > largest) ? candidate : largest;
  }
  return largest;
}

/*!
 * \brief Smallest value among the first num_leaves_ entries of leaf_value_.
 * \return The minimum leaf value; leaf 0 is always taken into account
 */
double Tree::GetLowerBoundValue() const {
  double smallest = leaf_value_[0];
  for (int leaf = 1; leaf < num_leaves_; ++leaf) {
    const double candidate = leaf_value_[leaf];
    smallest = (candidate < smallest) ? candidate : smallest;
  }
  return smallest;
}

Guolin Ke's avatar
Guolin Ke committed
220
std::string Tree::ToString() const {
221
  std::stringstream str_buf;
222
223
224
225
226
227
228
229
  Common::C_stringstream(str_buf);

  #if ((defined(sun) || defined(__sun)) && (defined(__SVR4) || defined(__svr4__)))
  using CommonLegacy::ArrayToString;  // Slower & unsafe regarding locale.
  #else
  using CommonC::ArrayToString;
  #endif

230
231
  str_buf << "num_leaves=" << num_leaves_ << '\n';
  str_buf << "num_cat=" << num_cat_ << '\n';
232
  str_buf << "split_feature="
233
    << ArrayToString(split_feature_, num_leaves_ - 1) << '\n';
234
  str_buf << "split_gain="
235
    << ArrayToString(split_gain_, num_leaves_ - 1) << '\n';
236
  str_buf << "threshold="
237
    << ArrayToString<true>(threshold_, num_leaves_ - 1) << '\n';
238
  str_buf << "decision_type="
239
    << ArrayToString(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1) << '\n';
240
  str_buf << "left_child="
241
    << ArrayToString(left_child_, num_leaves_ - 1) << '\n';
242
  str_buf << "right_child="
243
    << ArrayToString(right_child_, num_leaves_ - 1) << '\n';
244
  str_buf << "leaf_value="
245
    << ArrayToString<true>(leaf_value_, num_leaves_) << '\n';
246
  str_buf << "leaf_weight="
247
    << ArrayToString<true>(leaf_weight_, num_leaves_) << '\n';
248
  str_buf << "leaf_count="
249
    << ArrayToString(leaf_count_, num_leaves_) << '\n';
250
  str_buf << "internal_value="
251
    << ArrayToString(internal_value_, num_leaves_ - 1) << '\n';
252
  str_buf << "internal_weight="
253
    << ArrayToString(internal_weight_, num_leaves_ - 1) << '\n';
254
  str_buf << "internal_count="
255
    << ArrayToString(internal_count_, num_leaves_ - 1) << '\n';
256
257
  if (num_cat_ > 0) {
    str_buf << "cat_boundaries="
258
      << ArrayToString(cat_boundaries_, num_cat_ + 1) << '\n';
259
    str_buf << "cat_threshold="
260
      << ArrayToString(cat_threshold_, cat_threshold_.size()) << '\n';
261
  }
262
263
  str_buf << "shrinkage=" << shrinkage_ << '\n';
  str_buf << '\n';
264
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
265
266
}

Guolin Ke's avatar
Guolin Ke committed
267
std::string Tree::ToJSON() const {
268
  std::stringstream str_buf;
269
  Common::C_stringstream(str_buf);
270
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
271
272
273
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
  str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
wxchan's avatar
wxchan committed
274
  if (num_leaves_ == 1) {
275
    str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
wxchan's avatar
wxchan committed
276
  } else {
277
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
wxchan's avatar
wxchan committed
278
  }
wxchan's avatar
wxchan committed
279

280
  return str_buf.str();
wxchan's avatar
wxchan committed
281
282
}

Guolin Ke's avatar
Guolin Ke committed
283
std::string Tree::NodeToJSON(int index) const {
284
  std::stringstream str_buf;
285
  Common::C_stringstream(str_buf);
286
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
287
288
  if (index >= 0) {
    // non-leaf
289
290
291
    str_buf << "{" << '\n';
    str_buf << "\"split_index\":" << index << "," << '\n';
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
Guolin Ke's avatar
Guolin Ke committed
292
    str_buf << "\"split_gain\":" << Common::AvoidInf(split_gain_[index]) << "," << '\n';
293
    if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
294
295
296
297
298
299
300
301
302
303
304
      int cat_idx = static_cast<int>(threshold_[index]);
      std::vector<int> cats;
      for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
        for (int j = 0; j < 32; ++j) {
          int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
          if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
            cats.push_back(cat);
          }
        }
      }
305
      str_buf << "\"threshold\":\"" << CommonC::Join(cats, "||") << "\"," << '\n';
306
      str_buf << "\"decision_type\":\"==\"," << '\n';
307
    } else {
308
309
      str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << '\n';
      str_buf << "\"decision_type\":\"<=\"," << '\n';
310
311
    }
    if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
312
      str_buf << "\"default_left\":true," << '\n';
313
    } else {
314
      str_buf << "\"default_left\":false," << '\n';
315
316
    }
    uint8_t missing_type = GetMissingType(decision_type_[index]);
317
    if (missing_type == MissingType::None) {
318
      str_buf << "\"missing_type\":\"None\"," << '\n';
319
    } else if (missing_type == MissingType::Zero) {
320
      str_buf << "\"missing_type\":\"Zero\"," << '\n';
321
    } else {
322
      str_buf << "\"missing_type\":\"NaN\"," << '\n';
323
    }
324
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
325
    str_buf << "\"internal_weight\":" << internal_weight_[index] << "," << '\n';
326
327
328
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
329
    str_buf << "}";
wxchan's avatar
wxchan committed
330
331
332
  } else {
    // leaf
    index = ~index;
333
334
335
    str_buf << "{" << '\n';
    str_buf << "\"leaf_index\":" << index << "," << '\n';
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
336
    str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n';
337
    str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
338
    str_buf << "}";
wxchan's avatar
wxchan committed
339
340
  }

341
  return str_buf.str();
wxchan's avatar
wxchan committed
342
343
}

Guolin Ke's avatar
Guolin Ke committed
344
345
std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
346
  Common::C_stringstream(str_buf);
Guolin Ke's avatar
Guolin Ke committed
347
348
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
Nikita Titov's avatar
Nikita Titov committed
349
  if (missing_type == MissingType::None
350
      || (missing_type == MissingType::Zero && default_left && kZeroThreshold < threshold_[node])) {
Guolin Ke's avatar
Guolin Ke committed
351
    str_buf << "if (fval <= " << threshold_[node] << ") {";
352
  } else if (missing_type == MissingType::Zero) {
Guolin Ke's avatar
Guolin Ke committed
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
371
  Common::C_stringstream(str_buf);
372
  if (missing_type == MissingType::NaN) {
Guolin Ke's avatar
Guolin Ke committed
373
374
375
376
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
377
  int cat_idx = static_cast<int>(threshold_[node]);
378
379
380
381
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
Guolin Ke's avatar
Guolin Ke committed
382
383
384
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
385
std::string Tree::ToIfElse(int index, bool predict_leaf_index) const {
386
  std::stringstream str_buf;
387
  Common::C_stringstream(str_buf);
388
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
389
  if (predict_leaf_index) {
390
391
392
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
393
394
  if (num_leaves_ <= 1) {
    str_buf << "return " << leaf_value_[0] << ";";
395
  } else {
396
397
398
399
400
401
402
403
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
404
405
406
407
408
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
409
    str_buf << NodeToIfElse(0, predict_leaf_index);
410
  }
411
  str_buf << " }" << '\n';
412

413
  // Predict func by Map to ifelse
414
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
415
  if (predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
416
417
418
    str_buf << "LeafByMap";
  } else {
    str_buf << "ByMap";
419
420
421
  }
  str_buf << "(const std::unordered_map<int, double>& arr) { ";
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
    str_buf << "return " << leaf_value_[0] << ";";
  } else {
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
437
    str_buf << NodeToIfElseByMap(0, predict_leaf_index);
438
  }
439
  str_buf << " }" << '\n';
440

441
442
443
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
444
std::string Tree::NodeToIfElse(int index, bool predict_leaf_index) const {
445
  std::stringstream str_buf;
446
  Common::C_stringstream(str_buf);
447
448
449
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
450
    str_buf << "fval = arr[" << split_feature_[index] << "];";
Guolin Ke's avatar
Guolin Ke committed
451
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
452
      str_buf << NumericalDecisionIfElse(index);
453
    } else {
454
      str_buf << CategoricalDecisionIfElse(index);
455
456
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
457
    str_buf << NodeToIfElse(left_child_[index], predict_leaf_index);
458
459
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
460
    str_buf << NodeToIfElse(right_child_[index], predict_leaf_index);
461
462
463
464
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
465
    if (predict_leaf_index) {
466
467
468
469
470
471
472
473
474
475
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
476
std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
477
  std::stringstream str_buf;
478
  Common::C_stringstream(str_buf);
479
480
481
482
483
484
485
486
487
488
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
    str_buf << "fval = arr.count(" << split_feature_[index] << ") > 0 ? arr.at(" << split_feature_[index] << ") : 0.0f;";
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
      str_buf << NumericalDecisionIfElse(index);
    } else {
      str_buf << CategoricalDecisionIfElse(index);
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
489
    str_buf << NodeToIfElseByMap(left_child_[index], predict_leaf_index);
490
491
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
492
    str_buf << NodeToIfElseByMap(right_child_[index], predict_leaf_index);
493
494
495
496
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
497
    if (predict_leaf_index) {
498
499
500
501
502
503
504
505
506
507
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

508
509
/*!
 * \brief Load a tree from its "key=value" text form (the layout written by
 *        ToString).
 *
 * At most 17 lines are read; reading stops early at an empty line or at the
 * end of the buffer. num_leaves, num_cat and leaf_value are mandatory. For
 * trees with more than one leaf, left_child, right_child, split_feature and
 * threshold are mandatory too, and cat_boundaries / cat_threshold are
 * required when num_cat is positive. The remaining arrays are
 * default-sized when their key is missing.
 * \param str Start of the tree's text; assumed to be NUL-terminated
 * \param used_len Receives the number of characters consumed from str
 */
Tree::Tree(const char* str, size_t* used_len) {
  auto p = str;
  std::unordered_map<std::string, std::string> key_vals;
  const int max_num_line = 17;
  int read_line = 0;
  while (read_line < max_num_line) {
    // an empty line or the end of the buffer terminates the tree section
    if (*p == '\0' || *p == '\r' || *p == '\n') break;
    auto start = p;
    // never scan past the terminating NUL while looking for the separator
    while (*p != '=' && *p != '\0') ++p;
    // the buffer ended inside a key: there is no value to record
    if (*p == '\0') break;
    std::string key(start, p - start);
    ++p;
    start = p;
    // the value runs to the end of the line; a last line that lacks a
    // newline ends at the NUL instead
    while (*p != '\r' && *p != '\n' && *p != '\0') ++p;
    key_vals[key] = std::string(start, p - start);
    ++read_line;
    if (*p == '\r') ++p;
    if (*p == '\n') ++p;
  }
  *used_len = p - str;

  // set before any early return so that single-leaf trees are covered as well
  max_depth_ = -1;

  if (key_vals.count("num_leaves") <= 0) {
    Log::Fatal("Tree model should contain num_leaves field");
  }

  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);

  if (key_vals.count("num_cat") <= 0) {
    Log::Fatal("Tree model should contain num_cat field");
  }

  Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);

  if (key_vals.count("leaf_value")) {
    leaf_value_ = CommonC::StringToArray<double>(key_vals["leaf_value"], num_leaves_);
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

  if (key_vals.count("shrinkage")) {
    CommonC::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }

  // a single-leaf tree has no split data to load
  if (num_leaves_ <= 1) { return; }

  if (key_vals.count("left_child")) {
    left_child_ = CommonC::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
  }

  if (key_vals.count("right_child")) {
    right_child_ = CommonC::StringToArrayFast<int>(key_vals["right_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (key_vals.count("split_feature")) {
    split_feature_ = CommonC::StringToArrayFast<int>(key_vals["split_feature"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (key_vals.count("threshold")) {
    threshold_ = CommonC::StringToArray<double>(key_vals["threshold"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  // the arrays below are optional: a missing key yields default-sized storage
  if (key_vals.count("split_gain")) {
    split_gain_ = CommonC::StringToArrayFast<float>(key_vals["split_gain"], num_leaves_ - 1);
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_count")) {
    internal_count_ = CommonC::StringToArrayFast<int>(key_vals["internal_count"], num_leaves_ - 1);
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_value")) {
    internal_value_ = CommonC::StringToArrayFast<double>(key_vals["internal_value"], num_leaves_ - 1);
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_weight")) {
    internal_weight_ = CommonC::StringToArrayFast<double>(key_vals["internal_weight"], num_leaves_ - 1);
  } else {
    internal_weight_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("leaf_weight")) {
    leaf_weight_ = CommonC::StringToArray<double>(key_vals["leaf_weight"], num_leaves_);
  } else {
    leaf_weight_.resize(num_leaves_);
  }

  if (key_vals.count("leaf_count")) {
    leaf_count_ = CommonC::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (key_vals.count("decision_type")) {
    decision_type_ = CommonC::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }

  if (num_cat_ > 0) {
    if (key_vals.count("cat_boundaries")) {
      cat_boundaries_ = CommonC::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
    } else {
      Log::Fatal("Tree model should contain cat_boundaries field.");
    }

    // the last boundary is the total number of bitset words to read
    if (key_vals.count("cat_threshold")) {
      cat_threshold_ = CommonC::StringToArrayFast<uint32_t>(key_vals["cat_threshold"], cat_boundaries_.back());
    } else {
      Log::Fatal("Tree model should contain cat_threshold field");
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
/*!
 * \brief Write a new element at position unique_depth of the path and update
 *        the pweight of every element up to that position.
 * \param unique_path Path array; entries [0, unique_depth] are accessed
 * \param unique_depth Position at which the new element is stored
 * \param zero_fraction Value stored in the new element's zero_fraction, also
 *        used to rescale the existing weights
 * \param one_fraction Value stored in the new element's one_fraction, also
 *        used to propagate weight to the next position
 * \param feature_index Value stored in the new element's feature_index
 */
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  PathElement& added = unique_path[unique_depth];
  added.feature_index = feature_index;
  added.zero_fraction = zero_fraction;
  added.one_fraction = one_fraction;
  // only the very first element starts with weight 1
  added.pweight = (unique_depth == 0 ? 1 : 0);

  // walk backwards so that each weight is read before it is overwritten
  for (int pos = unique_depth - 1; pos >= 0; --pos) {
    unique_path[pos + 1].pweight += one_fraction*unique_path[pos].pweight*(pos + 1)
      / static_cast<double>(unique_depth + 1);
    unique_path[pos].pweight = zero_fraction*unique_path[pos].pweight*(unique_depth - pos)
      / static_cast<double>(unique_depth + 1);
  }
}

// Removes the element at `path_index` from the unique path, reversing the
// pweight update that ExtendPath applied for it.
//
// First the pweights of elements [0, unique_depth - 1] are recomputed using
// the one/zero fractions of the element being removed. Then the feature index
// and fractions of every element after `path_index` are shifted one slot
// towards the front (pweights are not shifted; they were rewritten above).
// The caller is responsible for decrementing its own depth counter.
//
// unique_path  : path buffer; indices [0, unique_depth] are accessed
// unique_depth : index of the last element currently on the path
// path_index   : index of the element to remove
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  // Snapshot the removed element's fractions before anything is overwritten.
  const double removed_one = unique_path[path_index].one_fraction;
  const double removed_zero = unique_path[path_index].zero_fraction;

  // The branch does not depend on the loop variable, so it is chosen once.
  if (removed_one != 0) {
    double carry = unique_path[unique_depth].pweight;
    for (int i = unique_depth - 1; i >= 0; --i) {
      PathElement& cur = unique_path[i];
      const double previous_weight = cur.pweight;
      cur.pweight = carry * (unique_depth + 1)
        / static_cast<double>((i + 1) * removed_one);
      carry = previous_weight - cur.pweight * removed_zero * (unique_depth - i)
        / static_cast<double>(unique_depth + 1);
    }
  } else {
    for (int i = unique_depth - 1; i >= 0; --i) {
      PathElement& cur = unique_path[i];
      cur.pweight = (cur.pweight * (unique_depth + 1))
        / static_cast<double>(removed_zero * (unique_depth - i));
    }
  }

  // Close the gap left by the removed element.
  for (int dst = path_index; dst < unique_depth; ++dst) {
    const PathElement& src = unique_path[dst + 1];
    unique_path[dst].feature_index = src.feature_index;
    unique_path[dst].zero_fraction = src.zero_fraction;
    unique_path[dst].one_fraction = src.one_fraction;
  }
}

// Returns the sum of the pweights the path would have if the element at
// `path_index` were removed, without modifying the path.
//
// The per-element values follow the same formulas UnwindPath uses to rewrite
// pweights; here they are only accumulated.
//
// unique_path  : path buffer; indices [0, unique_depth] are read
// unique_depth : index of the last element currently on the path
// path_index   : index of the element whose removal is simulated
double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
  const double removed_one = unique_path[path_index].one_fraction;
  const double removed_zero = unique_path[path_index].zero_fraction;
  const double path_len = static_cast<double>(unique_depth + 1);
  double sum = 0;

  // The branch does not depend on the loop variable, so it is chosen once.
  if (removed_one != 0) {
    double carry = unique_path[unique_depth].pweight;
    for (int i = unique_depth - 1; i >= 0; --i) {
      const double restored = carry * (unique_depth + 1)
        / static_cast<double>((i + 1) * removed_one);
      sum += restored;
      const double ratio = (unique_depth - i) / path_len;
      carry = unique_path[i].pweight - restored * removed_zero * ratio;
    }
    return sum;
  }

  for (int i = unique_depth - 1; i >= 0; --i) {
    const double ratio = (unique_depth - i) / path_len;
    sum += (unique_path[i].pweight / removed_zero) / ratio;
  }
  return sum;
}

// recursive computation of SHAP values for a decision tree
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
701
  PathElement* unique_path = parent_unique_path + unique_depth;
702
703
704
  if (unique_depth > 0) {
    std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  }
Guolin Ke's avatar
Guolin Ke committed
705
706
707
708
709
710
711
712
713
714
715
716
717
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

    // internal node
  } else {
718
    const int hot_index = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
719
720
721
722
723
724
725
726
727
728
729
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
730
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
Guolin Ke's avatar
Guolin Ke committed
731
732
733
734
735
736
737
738
739
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
740
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
741
742

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
743
             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
744
745
746
  }
}

747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
// recursive sparse computation of SHAP values for a decision tree
void Tree::TreeSHAPByMap(const std::unordered_map<int, double>& feature_values, std::unordered_map<int, double>* phi,
                         int node, int unique_depth,
                         PathElement *parent_unique_path, double parent_zero_fraction,
                         double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
  PathElement* unique_path = parent_unique_path + unique_depth;
  if (unique_depth > 0) {
    std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  }
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      (*phi)[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

  // internal node
  } else {
    const int hot_index = Decision(feature_values.count(split_feature_[node]) > 0 ? feature_values.at(split_feature_[node]) : 0.0f, node);
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAPByMap(feature_values, phi, hot_index, unique_depth + 1, unique_path,
                  hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);

    TreeSHAPByMap(feature_values, phi, cold_index, unique_depth + 1, unique_path,
                  cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
  }
}

799
800
801
802
803
double Tree::ExpectedValue() const {
  if (num_leaves_ == 1) return LeafOutput(0);
  const double total_count = internal_count_[0];
  double exp_value = 0.0;
  for (int i = 0; i < num_leaves(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
804
    exp_value += (leaf_count_[i] / total_count)*LeafOutput(i);
Guolin Ke's avatar
Guolin Ke committed
805
  }
806
  return exp_value;
Guolin Ke's avatar
Guolin Ke committed
807
808
}

809
810
811
812
813
814
815
816
817
818
819
// Recomputes max_depth_ as the largest entry of leaf_depth_ over all leaves.
//
// A tree with a single leaf has depth 0. If leaf_depth_ is empty, the leaf
// depths are rebuilt first via RecomputeLeafDepths(0, 0).
void Tree::RecomputeMaxDepth() {
  if (num_leaves_ == 1) {
    max_depth_ = 0;
    return;
  }
  if (leaf_depth_.empty()) {
    RecomputeLeafDepths(0, 0);
  }
  max_depth_ = leaf_depth_[0];
  for (int i = 1; i < num_leaves(); ++i) {
    if (max_depth_ < leaf_depth_[i]) max_depth_ = leaf_depth_[i];
  }
}

Guolin Ke's avatar
Guolin Ke committed
823
}  // namespace LightGBM