tree.cpp 29 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
#include <LightGBM/tree.h>

#include <LightGBM/dataset.h>
8
9
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
10
11

#include <functional>
12
#include <iomanip>
13
#include <sstream>
Guolin Ke's avatar
Guolin Ke committed
14
15
16
17
18

namespace LightGBM {

Tree::Tree(int max_leaves)
  :max_leaves_(max_leaves) {
  // A tree with max_leaves_ leaves has max_leaves_ - 1 internal (split)
  // nodes: per-node arrays get that size, per-leaf arrays get max_leaves_.
  const int num_nodes = max_leaves_ - 1;

  // Per-internal-node storage.
  left_child_.resize(num_nodes);
  right_child_.resize(num_nodes);
  split_feature_inner_.resize(num_nodes);
  split_feature_.resize(num_nodes);
  threshold_in_bin_.resize(num_nodes);
  threshold_.resize(num_nodes);
  decision_type_.resize(num_nodes, 0);
  split_gain_.resize(num_nodes);
  internal_value_.resize(num_nodes);
  internal_weight_.resize(num_nodes);
  internal_count_.resize(num_nodes);

  // Per-leaf storage.
  leaf_parent_.resize(max_leaves_);
  leaf_value_.resize(max_leaves_);
  leaf_weight_.resize(max_leaves_);
  leaf_count_.resize(max_leaves_);
  leaf_depth_.resize(max_leaves_);

  // The tree starts as a single root leaf: depth 0, no parent, zero output.
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_value_[0] = 0.0f;
  leaf_weight_[0] = 0.0f;
  leaf_parent_[0] = -1;

  shrinkage_ = 1.0f;

  // No categorical splits yet; both boundary arrays start at offset 0.
  num_cat_ = 0;
  cat_boundaries_.push_back(0);
  cat_boundaries_inner_.push_back(0);

  max_depth_ = -1;
}
Guolin Ke's avatar
Guolin Ke committed
47

Guolin Ke's avatar
Guolin Ke committed
48
49
50
// All members manage their own storage; nothing extra to release.
Tree::~Tree() = default;

51
52
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
                double threshold_double, double left_value, double right_value,
53
54
                int left_cnt, int right_cnt, double left_weight, double right_weight, float gain,
                MissingType missing_type, bool default_left) {
55
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
Guolin Ke's avatar
Guolin Ke committed
56
  int new_node_idx = num_leaves_ - 1;
Guolin Ke's avatar
Guolin Ke committed
57
  decision_type_[new_node_idx] = 0;
58
  SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
59
60
61
62
63
64
65
  SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
66
  }
Guolin Ke's avatar
Guolin Ke committed
67
  threshold_in_bin_[new_node_idx] = threshold_bin;
Guolin Ke's avatar
Guolin Ke committed
68
  threshold_[new_node_idx] = threshold_double;
69
70
71
  ++num_leaves_;
  return num_leaves_ - 1;
}
Guolin Ke's avatar
Guolin Ke committed
72

73
74
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                           const uint32_t* threshold, int num_threshold, double left_value, double right_value,
75
76
                           data_size_t left_cnt, data_size_t right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type) {
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
77
78
79
80
81
82
83
84
85
86
  int new_node_idx = num_leaves_ - 1;
  decision_type_[new_node_idx] = 0;
  SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
  }
87
88
  threshold_in_bin_[new_node_idx] = num_cat_;
  threshold_[new_node_idx] = num_cat_;
89
  ++num_cat_;
90
91
92
93
94
95
96
97
  cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
  for (int i = 0; i < num_threshold; ++i) {
    cat_threshold_.push_back(threshold[i]);
  }
  cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
  for (int i = 0; i < num_threshold_bin; ++i) {
    cat_threshold_inner_.push_back(threshold_bin[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
98
99
100
101
  ++num_leaves_;
  return num_leaves_ - 1;
}

Guolin Ke's avatar
Guolin Ke committed
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// PredictionFun expands to the body of a Threading::For callback that adds
// this tree's output to `score` for rows [start, end).
//
// The expansion uses names that must exist where it is expanded: `data`,
// `score`, `start`, `end`, `default_bins`, `max_bins` and the member
// `leaf_value_`. It introduces `i` (iterator / row counter) and `node`
// (current node), which the caller's arguments may refer to:
//   niter        - number of bin iterators to open
//   fidx_in_iter - inner feature index for iterator `i`
//   start_pos    - position each iterator is reset to before scanning
//   decision_fun - (bin, node, default_bin, max_bin) -> next node
//   iter_idx     - which iterator serves the current `node`
//   data_idx     - row index used for the iterator lookup and for `score`
// Traversal stops once `node` is negative; `~node` is then the leaf index.
#define PredictionFun(niter, fidx_in_iter, start_pos, decision_fun, iter_idx,   \
                      data_idx)                                                 \
  std::vector<std::unique_ptr<BinIterator>> bin_iters((niter));                 \
  for (int i = 0; i < (niter); ++i) {                                           \
    bin_iters[i].reset(data->FeatureIterator((fidx_in_iter)));                  \
    bin_iters[i]->Reset((start_pos));                                           \
  }                                                                             \
  for (data_size_t i = start; i < end; ++i) {                                   \
    int node = 0;                                                               \
    do {                                                                        \
      node = decision_fun(bin_iters[(iter_idx)]->Get((data_idx)), node,         \
                          default_bins[node], max_bins[node]);                  \
    } while (node >= 0);                                                        \
    score[(data_idx)] += static_cast<double>(leaf_value_[~node]);               \
  }

118
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
119
120
  if (num_leaves_ <= 1) {
    if (leaf_value_[0] != 0.0f) {
Guolin Ke's avatar
Guolin Ke committed
121
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
Guolin Ke's avatar
Guolin Ke committed
122
123
124
125
126
127
      for (data_size_t i = 0; i < num_data; ++i) {
        score[i] += leaf_value_[0];
      }
    }
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
128
129
130
131
132
133
134
135
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
136
  if (num_cat_ > 0) {
137
    if (data->num_features() > num_leaves_ - 1) {
Guolin Ke's avatar
Guolin Ke committed
138
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
139
140
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
141
142
      });
    } else {
Guolin Ke's avatar
Guolin Ke committed
143
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
144
145
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
146
147
      });
    }
Guolin Ke's avatar
Guolin Ke committed
148
  } else {
149
    if (data->num_features() > num_leaves_ - 1) {
Guolin Ke's avatar
Guolin Ke committed
150
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
151
152
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
153
154
      });
    } else {
Guolin Ke's avatar
Guolin Ke committed
155
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
156
157
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
158
159
      });
    }
Guolin Ke's avatar
Guolin Ke committed
160
  }
Guolin Ke's avatar
Guolin Ke committed
161
162
}

Guolin Ke's avatar
Guolin Ke committed
163
void Tree::AddPredictionToScore(const Dataset* data,
Guolin Ke's avatar
Guolin Ke committed
164
165
166
                                const data_size_t* used_data_indices,
                                data_size_t num_data, double* score) const {
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
167
    if (leaf_value_[0] != 0.0f) {
Guolin Ke's avatar
Guolin Ke committed
168
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
Guolin Ke's avatar
Guolin Ke committed
169
170
171
172
      for (data_size_t i = 0; i < num_data; ++i) {
        score[used_data_indices[i]] += leaf_value_[0];
      }
    }
Guolin Ke's avatar
Guolin Ke committed
173
    return;
Guolin Ke's avatar
Guolin Ke committed
174
  }
Guolin Ke's avatar
Guolin Ke committed
175
176
177
178
179
180
181
182
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
183
  if (num_cat_ > 0) {
184
    if (data->num_features() > num_leaves_ - 1) {
Guolin Ke's avatar
Guolin Ke committed
185
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
186
187
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
188
189
      });
    } else {
Guolin Ke's avatar
Guolin Ke committed
190
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
191
192
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
193
194
      });
    }
Guolin Ke's avatar
Guolin Ke committed
195
  } else {
196
    if (data->num_features() > num_leaves_ - 1) {
Guolin Ke's avatar
Guolin Ke committed
197
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
198
199
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
200
201
      });
    } else {
Guolin Ke's avatar
Guolin Ke committed
202
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
203
204
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
205
206
      });
    }
Guolin Ke's avatar
Guolin Ke committed
207
  }
Guolin Ke's avatar
Guolin Ke committed
208
209
}

210
211
#undef PredictionFun

212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
// Largest output value among the tree's active leaves [0, num_leaves_).
double Tree::GetUpperBoundValue() const {
  double best = leaf_value_[0];
  for (int leaf = 1; leaf < num_leaves_; ++leaf) {
    best = (leaf_value_[leaf] > best) ? leaf_value_[leaf] : best;
  }
  return best;
}

// Smallest output value among the tree's active leaves [0, num_leaves_).
double Tree::GetLowerBoundValue() const {
  double best = leaf_value_[0];
  for (int leaf = 1; leaf < num_leaves_; ++leaf) {
    best = (leaf_value_[leaf] < best) ? leaf_value_[leaf] : best;
  }
  return best;
}

Guolin Ke's avatar
Guolin Ke committed
232
std::string Tree::ToString() const {
233
  std::stringstream str_buf;
234
235
  str_buf << "num_leaves=" << num_leaves_ << '\n';
  str_buf << "num_cat=" << num_cat_ << '\n';
236
  str_buf << "split_feature="
237
    << Common::ArrayToStringFast(split_feature_, num_leaves_ - 1) << '\n';
238
  str_buf << "split_gain="
239
    << Common::ArrayToStringFast(split_gain_, num_leaves_ - 1) << '\n';
240
  str_buf << "threshold="
241
    << Common::ArrayToString(threshold_, num_leaves_ - 1) << '\n';
242
  str_buf << "decision_type="
243
    << Common::ArrayToStringFast(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1) << '\n';
244
  str_buf << "left_child="
245
    << Common::ArrayToStringFast(left_child_, num_leaves_ - 1) << '\n';
246
  str_buf << "right_child="
247
    << Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
248
  str_buf << "leaf_value="
249
    << Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
250
251
  str_buf << "leaf_weight="
    << Common::ArrayToString(leaf_weight_, num_leaves_) << '\n';
252
  str_buf << "leaf_count="
253
    << Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
254
  str_buf << "internal_value="
255
    << Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
256
257
  str_buf << "internal_weight="
    << Common::ArrayToStringFast(internal_weight_, num_leaves_ - 1) << '\n';
258
  str_buf << "internal_count="
259
    << Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
260
261
  if (num_cat_ > 0) {
    str_buf << "cat_boundaries="
262
      << Common::ArrayToStringFast(cat_boundaries_, num_cat_ + 1) << '\n';
263
    str_buf << "cat_threshold="
264
      << Common::ArrayToStringFast(cat_threshold_, cat_threshold_.size()) << '\n';
265
  }
266
267
  str_buf << "shrinkage=" << shrinkage_ << '\n';
  str_buf << '\n';
268
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
269
270
}

Guolin Ke's avatar
Guolin Ke committed
271
std::string Tree::ToJSON() const {
272
  std::stringstream str_buf;
273
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
274
275
276
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
  str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
wxchan's avatar
wxchan committed
277
  if (num_leaves_ == 1) {
278
    str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
wxchan's avatar
wxchan committed
279
  } else {
280
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
wxchan's avatar
wxchan committed
281
  }
wxchan's avatar
wxchan committed
282

283
  return str_buf.str();
wxchan's avatar
wxchan committed
284
285
}

Guolin Ke's avatar
Guolin Ke committed
286
std::string Tree::NodeToJSON(int index) const {
287
  std::stringstream str_buf;
288
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
289
290
  if (index >= 0) {
    // non-leaf
291
292
293
    str_buf << "{" << '\n';
    str_buf << "\"split_index\":" << index << "," << '\n';
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
Guolin Ke's avatar
Guolin Ke committed
294
    str_buf << "\"split_gain\":" << Common::AvoidInf(split_gain_[index]) << "," << '\n';
295
    if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
296
297
298
299
300
301
302
303
304
305
306
      int cat_idx = static_cast<int>(threshold_[index]);
      std::vector<int> cats;
      for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
        for (int j = 0; j < 32; ++j) {
          int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
          if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
            cats.push_back(cat);
          }
        }
      }
307
308
      str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << '\n';
      str_buf << "\"decision_type\":\"==\"," << '\n';
309
    } else {
310
311
      str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << '\n';
      str_buf << "\"decision_type\":\"<=\"," << '\n';
312
313
    }
    if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
314
      str_buf << "\"default_left\":true," << '\n';
315
    } else {
316
      str_buf << "\"default_left\":false," << '\n';
317
318
319
    }
    uint8_t missing_type = GetMissingType(decision_type_[index]);
    if (missing_type == 0) {
320
      str_buf << "\"missing_type\":\"None\"," << '\n';
321
    } else if (missing_type == 1) {
322
      str_buf << "\"missing_type\":\"Zero\"," << '\n';
323
    } else {
324
      str_buf << "\"missing_type\":\"NaN\"," << '\n';
325
    }
326
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
327
    str_buf << "\"internal_weight\":" << internal_weight_[index] << "," << '\n';
328
329
330
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
331
    str_buf << "}";
wxchan's avatar
wxchan committed
332
333
334
  } else {
    // leaf
    index = ~index;
335
336
337
    str_buf << "{" << '\n';
    str_buf << "\"leaf_index\":" << index << "," << '\n';
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
338
    str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n';
339
    str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
340
    str_buf << "}";
wxchan's avatar
wxchan committed
341
342
  }

343
  return str_buf.str();
wxchan's avatar
wxchan committed
344
345
}

Guolin Ke's avatar
Guolin Ke committed
346
347
348
349
std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
Guolin Ke's avatar
Guolin Ke committed
350
  if (missing_type == 0 || (missing_type == 1 && default_left && kZeroThreshold < threshold_[node])) {
Guolin Ke's avatar
Guolin Ke committed
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
    str_buf << "if (fval <= " << threshold_[node] << ") {";
  } else if (missing_type == 1) {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
  if (missing_type == 2) {
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
376
  int cat_idx = static_cast<int>(threshold_[node]);
377
378
379
380
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
Guolin Ke's avatar
Guolin Ke committed
381
382
383
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
384
std::string Tree::ToIfElse(int index, bool predict_leaf_index) const {
385
386
  std::stringstream str_buf;
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
387
  if (predict_leaf_index) {
388
389
390
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
391
392
  if (num_leaves_ <= 1) {
    str_buf << "return " << leaf_value_[0] << ";";
393
  } else {
394
395
396
397
398
399
400
401
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
402
403
404
405
406
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
407
    str_buf << NodeToIfElse(0, predict_leaf_index);
408
  }
409
  str_buf << " }" << '\n';
410

411
  // Predict func by Map to ifelse
412
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
413
  if (predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
414
415
416
    str_buf << "LeafByMap";
  } else {
    str_buf << "ByMap";
417
418
419
  }
  str_buf << "(const std::unordered_map<int, double>& arr) { ";
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
    str_buf << "return " << leaf_value_[0] << ";";
  } else {
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
435
    str_buf << NodeToIfElseByMap(0, predict_leaf_index);
436
  }
437
  str_buf << " }" << '\n';
438

439
440
441
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
442
std::string Tree::NodeToIfElse(int index, bool predict_leaf_index) const {
443
444
445
446
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
447
    str_buf << "fval = arr[" << split_feature_[index] << "];";
Guolin Ke's avatar
Guolin Ke committed
448
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
449
      str_buf << NumericalDecisionIfElse(index);
450
    } else {
451
      str_buf << CategoricalDecisionIfElse(index);
452
453
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
454
    str_buf << NodeToIfElse(left_child_[index], predict_leaf_index);
455
456
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
457
    str_buf << NodeToIfElse(right_child_[index], predict_leaf_index);
458
459
460
461
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
462
    if (predict_leaf_index) {
463
464
465
466
467
468
469
470
471
472
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
473
std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
474
475
476
477
478
479
480
481
482
483
484
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
    str_buf << "fval = arr.count(" << split_feature_[index] << ") > 0 ? arr.at(" << split_feature_[index] << ") : 0.0f;";
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
      str_buf << NumericalDecisionIfElse(index);
    } else {
      str_buf << CategoricalDecisionIfElse(index);
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
485
    str_buf << NodeToIfElseByMap(left_child_[index], predict_leaf_index);
486
487
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
488
    str_buf << NodeToIfElseByMap(right_child_[index], predict_leaf_index);
489
490
491
492
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
493
    if (predict_leaf_index) {
494
495
496
497
498
499
500
501
502
503
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

504
505
// Parses one tree section, as produced by ToString(), from `str`.
// The section is a sequence of "key=value" lines ended by an empty line (or
// by the end of the string). On return, *used_len holds the number of
// characters consumed. Required fields that are missing are reported via
// Log::Fatal.
Tree::Tree(const char* str, size_t* used_len) {
  auto p = str;
  std::unordered_map<std::string, std::string> key_vals;
  // Upper bound on the number of "key=value" lines in one tree section.
  const int max_num_line = 17;
  int read_line = 0;
  // Every scan below also stops at '\0', so a truncated or malformed model
  // string cannot make the parser read past its terminator.
  while (read_line < max_num_line) {
    if (*p == '\0' || *p == '\r' || *p == '\n') break;
    auto start = p;
    while (*p != '=' && *p != '\0' && *p != '\r' && *p != '\n') ++p;
    if (*p != '=') {
      Log::Fatal("Tree model string format error, expected a key=value line");
    }
    std::string key(start, p - start);
    ++p;  // skip '='
    start = p;
    while (*p != '\0' && *p != '\r' && *p != '\n') ++p;
    key_vals[key] = std::string(start, p - start);
    ++read_line;
    if (*p == '\r') ++p;
    if (*p == '\n') ++p;
  }
  *used_len = p - str;

  // Set before the single-leaf early return below, which previously skipped
  // this assignment.
  max_depth_ = -1;

  if (key_vals.count("num_leaves") == 0) {
    Log::Fatal("Tree model should contain num_leaves field");
  }
  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);

  if (key_vals.count("num_cat") == 0) {
    Log::Fatal("Tree model should contain num_cat field");
  }
  Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);

  if (key_vals.count("leaf_value")) {
    leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], num_leaves_);
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

  if (key_vals.count("shrinkage")) {
    Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }

  // A single-leaf tree has no internal nodes to read.
  if (num_leaves_ <= 1) { return; }

  // Required per-internal-node structure.
  if (key_vals.count("left_child")) {
    left_child_ = Common::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
  }

  if (key_vals.count("right_child")) {
    right_child_ = Common::StringToArrayFast<int>(key_vals["right_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (key_vals.count("split_feature")) {
    split_feature_ = Common::StringToArrayFast<int>(key_vals["split_feature"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (key_vals.count("threshold")) {
    threshold_ = Common::StringToArray<double>(key_vals["threshold"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  // Optional statistics default to value-initialized arrays.
  if (key_vals.count("split_gain")) {
    split_gain_ = Common::StringToArrayFast<float>(key_vals["split_gain"], num_leaves_ - 1);
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_count")) {
    internal_count_ = Common::StringToArrayFast<int>(key_vals["internal_count"], num_leaves_ - 1);
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_value")) {
    internal_value_ = Common::StringToArrayFast<double>(key_vals["internal_value"], num_leaves_ - 1);
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_weight")) {
    internal_weight_ = Common::StringToArrayFast<double>(key_vals["internal_weight"], num_leaves_ - 1);
  } else {
    internal_weight_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("leaf_weight")) {
    leaf_weight_ = Common::StringToArrayFast<double>(key_vals["leaf_weight"], num_leaves_);
  } else {
    leaf_weight_.resize(num_leaves_);
  }

  if (key_vals.count("leaf_count")) {
    leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (key_vals.count("decision_type")) {
    decision_type_ = Common::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }

  // Category bitsets are required whenever categorical splits exist.
  if (num_cat_ > 0) {
    if (key_vals.count("cat_boundaries")) {
      cat_boundaries_ = Common::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
    } else {
      Log::Fatal("Tree model should contain cat_boundaries field.");
    }

    if (key_vals.count("cat_threshold")) {
      cat_threshold_ = Common::StringToArrayFast<uint32_t>(key_vals["cat_threshold"], cat_boundaries_.back());
    } else {
      Log::Fatal("Tree model should contain cat_threshold field");
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
// TreeSHAP helper: writes a new path element for `feature_index` at position
// `unique_depth` of `unique_path`. It then rescales the `pweight` of elements
// 0..unique_depth so they account for the new element.
// zero_fraction / one_fraction are stored on the new element and used as the
// scaling factors below; their exact meaning is defined by the TreeSHAP
// caller (presumably the fraction of data following this branch, and whether
// the explained sample follows it) -- confirm against TreeSHAP.
// NOTE: the floating-point expression order matters for bit-exact results.
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  unique_path[unique_depth].feature_index = feature_index;
  unique_path[unique_depth].zero_fraction = zero_fraction;
  unique_path[unique_depth].one_fraction = one_fraction;
  // The very first element starts with weight 1; later ones start at 0 and
  // receive weight from their predecessors in the loop below.
  unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
  // Walk backwards so unique_path[i].pweight is read (to feed i + 1) before
  // it is overwritten with its own rescaled value.
  for (int i = unique_depth - 1; i >= 0; i--) {
    unique_path[i + 1].pweight += one_fraction*unique_path[i].pweight*(i + 1)
      / static_cast<double>(unique_depth + 1);
    unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth - i)
      / static_cast<double>(unique_depth + 1);
  }
}

// TreeSHAP helper: removes the element at `path_index` from a path of
// length unique_depth + 1, undoing the pweight scaling that ExtendPath applied
// for it. It rewrites pweight of elements 0..unique_depth-1 in place and
// shifts feature_index / zero_fraction / one_fraction of the later elements
// down by one.
// NOTE: the floating-point expression order matters for bit-exact results.
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;

  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      // Keep the old weight; it is needed to carry the remainder to i - 1.
      const double tmp = unique_path[i].pweight;
      unique_path[i].pweight = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth - i)
        / static_cast<double>(unique_depth + 1);
    } else {
      // NOTE(review): divides by zero_fraction; assumes it is non-zero
      // whenever one_fraction is zero -- confirm with TreeSHAP callers.
      unique_path[i].pweight = (unique_path[i].pweight*(unique_depth + 1))
        / static_cast<double>(zero_fraction*(unique_depth - i));
    }
  }

  // Close the gap left by the removed element (pweight was rewritten above).
  for (int i = path_index; i < unique_depth; ++i) {
    unique_path[i].feature_index = unique_path[i + 1].feature_index;
    unique_path[i].zero_fraction = unique_path[i + 1].zero_fraction;
    unique_path[i].one_fraction = unique_path[i + 1].one_fraction;
  }
}

// Read-only counterpart of UnwindPath: returns the sum of the permutation
// weights the path would have if the element at path_index were unwound,
// without modifying unique_path.
//
// unique_path  : path buffer whose valid slots are [0, unique_depth].
// unique_depth : index of the last valid slot.
// path_index   : slot of the element whose removal is simulated.
double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
  const double one = unique_path[path_index].one_fraction;
  const double zero = unique_path[path_index].zero_fraction;
  const int width = unique_depth + 1;
  double sum = 0;

  if (one != 0) {
    double carry = unique_path[unique_depth].pweight;
    for (int i = unique_depth - 1; i >= 0; --i) {
      const double part = carry * width / static_cast<double>((i + 1) * one);
      sum += part;
      carry = unique_path[i].pweight - part * zero * ((unique_depth - i)
                                                      / static_cast<double>(width));
    }
  } else {
    for (int i = unique_depth - 1; i >= 0; --i) {
      sum += (unique_path[i].pweight / zero) / ((unique_depth - i)
                                                / static_cast<double>(width));
    }
  }
  return sum;
}

// recursive computation of SHAP values for a decision tree
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
697
  PathElement* unique_path = parent_unique_path + unique_depth;
Guolin Ke's avatar
Guolin Ke committed
698
699
700
701
702
703
704
705
706
707
708
709
710
711
  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

    // internal node
  } else {
712
    const int hot_index = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
713
714
715
716
717
718
719
720
721
722
723
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
724
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
Guolin Ke's avatar
Guolin Ke committed
725
726
727
728
729
730
731
732
733
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
734
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
735
736

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
737
             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
738
739
740
  }
}

741
742
743
744
745
double Tree::ExpectedValue() const {
  if (num_leaves_ == 1) return LeafOutput(0);
  const double total_count = internal_count_[0];
  double exp_value = 0.0;
  for (int i = 0; i < num_leaves(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
746
    exp_value += (leaf_count_[i] / total_count)*LeafOutput(i);
Guolin Ke's avatar
Guolin Ke committed
747
  }
748
  return exp_value;
Guolin Ke's avatar
Guolin Ke committed
749
750
}

751
752
753
754
755
756
757
758
759
760
761
// Recomputes max_depth_ as the largest entry of leaf_depth_ over the current
// leaves. A single-leaf tree has depth 0. If leaf_depth_ is empty it is first
// rebuilt via RecomputeLeafDepths(0, 0).
void Tree::RecomputeMaxDepth() {
  if (num_leaves_ == 1) {
    max_depth_ = 0;
    return;
  }
  if (leaf_depth_.empty()) {
    RecomputeLeafDepths(0, 0);
  }
  max_depth_ = leaf_depth_[0];
  for (int i = 1; i < num_leaves(); ++i) {
    if (max_depth_ < leaf_depth_[i]) max_depth_ = leaf_depth_[i];
  }
}

Guolin Ke's avatar
Guolin Ke committed
765
}  // namespace LightGBM