/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#include <LightGBM/tree.h>

#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>

#include <algorithm>
#include <functional>
#include <iomanip>
#include <limits>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>

namespace LightGBM {

/*!
 * \brief Creates a tree consisting of a single root leaf with output 0, with
 *        every per-node and per-leaf array pre-sized for up to max_leaves leaves.
 * \param max_leaves Maximum number of leaves. A binary tree with that many
 *        leaves has max_leaves - 1 internal (split) nodes.
 * NOTE(review): assumes max_leaves >= 1, because leaf 0 is written below.
 */
Tree::Tree(int max_leaves)
  :max_leaves_(max_leaves) {
  const int max_nodes = max_leaves_ - 1;
  // arrays indexed by internal (split) node
  left_child_.resize(max_nodes);
  right_child_.resize(max_nodes);
  split_feature_inner_.resize(max_nodes);
  split_feature_.resize(max_nodes);
  threshold_in_bin_.resize(max_nodes);
  threshold_.resize(max_nodes);
  decision_type_.resize(max_nodes, 0);
  split_gain_.resize(max_nodes);
  internal_value_.resize(max_nodes);
  internal_count_.resize(max_nodes);
  // arrays indexed by leaf
  leaf_parent_.resize(max_leaves_);
  leaf_value_.resize(max_leaves_);
  leaf_count_.resize(max_leaves_);
  leaf_depth_.resize(max_leaves_);

  // the tree starts as one root leaf: depth 0, output 0, no parent
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_value_[0] = 0.0f;
  leaf_parent_[0] = -1;
  shrinkage_ = 1.0f;
  // no categorical splits yet; each boundary array starts with its leading 0 offset
  num_cat_ = 0;
  cat_boundaries_.push_back(0);
  cat_boundaries_inner_.push_back(0);
  // placeholder until RecomputeMaxDepth() assigns a real value
  max_depth_ = -1;
}

Guolin Ke's avatar
Guolin Ke committed
45
46
47
Tree::~Tree() {
}

48
49
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
                double threshold_double, double left_value, double right_value,
50
                int left_cnt, int right_cnt, float gain, MissingType missing_type, bool default_left) {
51
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
Guolin Ke's avatar
Guolin Ke committed
52
  int new_node_idx = num_leaves_ - 1;
Guolin Ke's avatar
Guolin Ke committed
53
  decision_type_[new_node_idx] = 0;
54
  SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
55
56
57
58
59
60
61
  SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
62
  }
Guolin Ke's avatar
Guolin Ke committed
63
  threshold_in_bin_[new_node_idx] = threshold_bin;
64
  threshold_[new_node_idx] = Common::AvoidInf(threshold_double);
65
66
67
  ++num_leaves_;
  return num_leaves_ - 1;
}
Guolin Ke's avatar
Guolin Ke committed
68

69
70
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                           const uint32_t* threshold, int num_threshold, double left_value, double right_value,
71
                           data_size_t left_cnt, data_size_t right_cnt, float gain, MissingType missing_type) {
72
73
74
75
76
77
78
79
80
81
82
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, gain);
  int new_node_idx = num_leaves_ - 1;
  decision_type_[new_node_idx] = 0;
  SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
  }
83
84
  threshold_in_bin_[new_node_idx] = num_cat_;
  threshold_[new_node_idx] = num_cat_;
85
  ++num_cat_;
86
87
88
89
90
91
92
93
  cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
  for (int i = 0; i < num_threshold; ++i) {
    cat_threshold_.push_back(threshold[i]);
  }
  cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
  for (int i = 0; i < num_threshold_bin; ++i) {
    cat_threshold_inner_.push_back(threshold_bin[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
94
95
96
97
  ++num_leaves_;
  return num_leaves_ - 1;
}

98
99
100
101
102
103
104
105
106
107
108
109
110
111
// Prediction kernel shared by the two AddPredictionToScore overloads below.
// It is a macro rather than a function because it reads names that must exist
// at the expansion site: `data`, `start`, `end`, `score`, `default_bins`,
// `max_bins` and `leaf_value_`.
//   niter        - number of bin iterators to open
//   fidx_in_iter - inner feature index for iterator `i` (may use the loop variable `i`)
//   start_pos    - position each iterator is Reset() to before scanning rows
//   decision_fun - called as decision_fun(bin, node, default_bin, max_bin); returns
//                  the next node, or a negative value ~leaf once a leaf is reached
//   iter_idx     - which iterator to read for the current `node` (may use `node`)
//   data_idx     - row passed to Get() and index into `score` (may use the row variable `i`)
// For each row in [start, end), the tree is walked from the root and the
// reached leaf's value is added to score[data_idx].
#define PredictionFun(niter, fidx_in_iter, start_pos, decision_fun, iter_idx, data_idx) \
std::vector<std::unique_ptr<BinIterator>> iter((niter)); \
for (int i = 0; i < (niter); ++i) { \
  iter[i].reset(data->FeatureIterator((fidx_in_iter))); \
  iter[i]->Reset((start_pos)); \
}\
for (data_size_t i = start; i < end; ++i) {\
  int node = 0;\
  while (node >= 0) {\
    node = decision_fun(iter[(iter_idx)]->Get((data_idx)), node, default_bins[node], max_bins[node]);\
  }\
  score[(data_idx)] += static_cast<double>(leaf_value_[~node]);\
}\

112
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
113
114
115
116
117
118
119
120
121
  if (num_leaves_ <= 1) {
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[i] += leaf_value_[0];
      }
    }
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
122
123
124
125
126
127
128
129
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
130
  if (num_cat_ > 0) {
131
    if (data->num_features() > num_leaves_ - 1) {
132
133
134
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
135
136
      });
    } else {
137
138
139
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
140
141
      });
    }
Guolin Ke's avatar
Guolin Ke committed
142
  } else {
143
    if (data->num_features() > num_leaves_ - 1) {
144
145
146
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
147
148
      });
    } else {
149
150
151
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
152
153
      });
    }
Guolin Ke's avatar
Guolin Ke committed
154
  }
Guolin Ke's avatar
Guolin Ke committed
155
156
}

Guolin Ke's avatar
Guolin Ke committed
157
void Tree::AddPredictionToScore(const Dataset* data,
Guolin Ke's avatar
Guolin Ke committed
158
159
160
                                const data_size_t* used_data_indices,
                                data_size_t num_data, double* score) const {
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
161
162
163
164
165
166
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[used_data_indices[i]] += leaf_value_[0];
      }
    }
Guolin Ke's avatar
Guolin Ke committed
167
    return;
Guolin Ke's avatar
Guolin Ke committed
168
  }
Guolin Ke's avatar
Guolin Ke committed
169
170
171
172
173
174
175
176
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
177
  if (num_cat_ > 0) {
178
    if (data->num_features() > num_leaves_ - 1) {
179
180
181
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
182
183
      });
    } else {
184
185
186
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
187
188
      });
    }
Guolin Ke's avatar
Guolin Ke committed
189
  } else {
190
    if (data->num_features() > num_leaves_ - 1) {
191
192
193
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
194
195
      });
    } else {
196
197
198
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
199
200
      });
    }
Guolin Ke's avatar
Guolin Ke committed
201
  }
Guolin Ke's avatar
Guolin Ke committed
202
203
}

204
205
// PredictionFun is only meant for the AddPredictionToScore overloads; keep it
// from leaking into the rest of this translation unit.
#undef PredictionFun

Guolin Ke's avatar
Guolin Ke committed
206
std::string Tree::ToString() const {
207
  std::stringstream str_buf;
208
209
  str_buf << "num_leaves=" << num_leaves_ << '\n';
  str_buf << "num_cat=" << num_cat_ << '\n';
210
  str_buf << "split_feature="
211
    << Common::ArrayToStringFast(split_feature_, num_leaves_ - 1) << '\n';
212
  str_buf << "split_gain="
213
    << Common::ArrayToStringFast(split_gain_, num_leaves_ - 1) << '\n';
214
  str_buf << "threshold="
215
    << Common::ArrayToString(threshold_, num_leaves_ - 1) << '\n';
216
  str_buf << "decision_type="
217
    << Common::ArrayToStringFast(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1) << '\n';
218
  str_buf << "left_child="
219
    << Common::ArrayToStringFast(left_child_, num_leaves_ - 1) << '\n';
220
  str_buf << "right_child="
221
    << Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
222
  str_buf << "leaf_value="
223
    << Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
224
  str_buf << "leaf_count="
225
    << Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
226
  str_buf << "internal_value="
227
    << Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
228
  str_buf << "internal_count="
229
    << Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
230
231
  if (num_cat_ > 0) {
    str_buf << "cat_boundaries="
232
      << Common::ArrayToStringFast(cat_boundaries_, num_cat_ + 1) << '\n';
233
    str_buf << "cat_threshold="
234
      << Common::ArrayToStringFast(cat_threshold_, cat_threshold_.size()) << '\n';
235
  }
236
237
  str_buf << "shrinkage=" << shrinkage_ << '\n';
  str_buf << '\n';
238
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
239
240
}

Guolin Ke's avatar
Guolin Ke committed
241
std::string Tree::ToJSON() const {
242
  std::stringstream str_buf;
243
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
244
245
246
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
  str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
wxchan's avatar
wxchan committed
247
  if (num_leaves_ == 1) {
248
    str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
wxchan's avatar
wxchan committed
249
  } else {
250
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
wxchan's avatar
wxchan committed
251
  }
wxchan's avatar
wxchan committed
252

253
  return str_buf.str();
wxchan's avatar
wxchan committed
254
255
}

Guolin Ke's avatar
Guolin Ke committed
256
std::string Tree::NodeToJSON(int index) const {
257
  std::stringstream str_buf;
258
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
259
260
  if (index >= 0) {
    // non-leaf
261
262
263
264
    str_buf << "{" << '\n';
    str_buf << "\"split_index\":" << index << "," << '\n';
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
    str_buf << "\"split_gain\":" << split_gain_[index] << "," << '\n';
265
    if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
266
267
268
269
270
271
272
273
274
275
276
      int cat_idx = static_cast<int>(threshold_[index]);
      std::vector<int> cats;
      for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
        for (int j = 0; j < 32; ++j) {
          int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
          if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
            cats.push_back(cat);
          }
        }
      }
277
278
      str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << '\n';
      str_buf << "\"decision_type\":\"==\"," << '\n';
279
    } else {
280
281
      str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << '\n';
      str_buf << "\"decision_type\":\"<=\"," << '\n';
282
283
    }
    if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
284
      str_buf << "\"default_left\":true," << '\n';
285
    } else {
286
      str_buf << "\"default_left\":false," << '\n';
287
288
289
    }
    uint8_t missing_type = GetMissingType(decision_type_[index]);
    if (missing_type == 0) {
290
      str_buf << "\"missing_type\":\"None\"," << '\n';
291
    } else if (missing_type == 1) {
292
      str_buf << "\"missing_type\":\"Zero\"," << '\n';
293
    } else {
294
      str_buf << "\"missing_type\":\"NaN\"," << '\n';
295
    }
296
297
298
299
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
300
    str_buf << "}";
wxchan's avatar
wxchan committed
301
302
303
  } else {
    // leaf
    index = ~index;
304
305
306
307
    str_buf << "{" << '\n';
    str_buf << "\"leaf_index\":" << index << "," << '\n';
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
    str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
308
    str_buf << "}";
wxchan's avatar
wxchan committed
309
310
  }

311
  return str_buf.str();
wxchan's avatar
wxchan committed
312
313
}

Guolin Ke's avatar
Guolin Ke committed
314
315
316
317
std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
Guolin Ke's avatar
Guolin Ke committed
318
  if (missing_type == 0 || (missing_type == 1 && default_left && kZeroThreshold < threshold_[node])) {
Guolin Ke's avatar
Guolin Ke committed
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
    str_buf << "if (fval <= " << threshold_[node] << ") {";
  } else if (missing_type == 1) {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
  if (missing_type == 2) {
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
344
  int cat_idx = static_cast<int>(threshold_[node]);
345
346
347
348
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
Guolin Ke's avatar
Guolin Ke committed
349
350
351
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
352
std::string Tree::ToIfElse(int index, bool predict_leaf_index) const {
353
354
  std::stringstream str_buf;
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
355
  if (predict_leaf_index) {
356
357
358
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
359
360
  if (num_leaves_ <= 1) {
    str_buf << "return " << leaf_value_[0] << ";";
361
  } else {
362
363
364
365
366
367
368
369
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
370
371
372
373
374
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
375
    str_buf << NodeToIfElse(0, predict_leaf_index);
376
  }
377
  str_buf << " }" << '\n';
378

379
  // Predict func by Map to ifelse
380
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
381
  if (predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
382
383
384
    str_buf << "LeafByMap";
  } else {
    str_buf << "ByMap";
385
386
387
  }
  str_buf << "(const std::unordered_map<int, double>& arr) { ";
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
    str_buf << "return " << leaf_value_[0] << ";";
  } else {
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
403
    str_buf << NodeToIfElseByMap(0, predict_leaf_index);
404
  }
405
  str_buf << " }" << '\n';
406

407
408
409
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
410
std::string Tree::NodeToIfElse(int index, bool predict_leaf_index) const {
411
412
413
414
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
415
    str_buf << "fval = arr[" << split_feature_[index] << "];";
Guolin Ke's avatar
Guolin Ke committed
416
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
417
      str_buf << NumericalDecisionIfElse(index);
418
    } else {
419
      str_buf << CategoricalDecisionIfElse(index);
420
421
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
422
    str_buf << NodeToIfElse(left_child_[index], predict_leaf_index);
423
424
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
425
    str_buf << NodeToIfElse(right_child_[index], predict_leaf_index);
426
427
428
429
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
430
    if (predict_leaf_index) {
431
432
433
434
435
436
437
438
439
440
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
441
std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
442
443
444
445
446
447
448
449
450
451
452
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
    str_buf << "fval = arr.count(" << split_feature_[index] << ") > 0 ? arr.at(" << split_feature_[index] << ") : 0.0f;";
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
      str_buf << NumericalDecisionIfElse(index);
    } else {
      str_buf << CategoricalDecisionIfElse(index);
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
453
    str_buf << NodeToIfElseByMap(left_child_[index], predict_leaf_index);
454
455
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
456
    str_buf << NodeToIfElseByMap(right_child_[index], predict_leaf_index);
457
458
459
460
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
461
    if (predict_leaf_index) {
462
463
464
465
466
467
468
469
470
471
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

472
473
Tree::Tree(const char* str, size_t* used_len) {
  auto p = str;
Guolin Ke's avatar
Guolin Ke committed
474
  std::unordered_map<std::string, std::string> key_vals;
475
476
477
478
479
  const int max_num_line = 15;
  int read_line = 0;
  while (read_line < max_num_line) {
    if (*p == '\r' || *p == '\n') break;
    auto start = p;
480
    while (*p != '=') ++p;
481
482
483
484
485
486
487
488
489
490
491
    std::string key(start, p - start);
    ++p;
    start = p;
    while (*p != '\r' && *p != '\n') ++p;
    key_vals[key] = std::string(start, p - start);
    ++read_line;
    if (*p == '\r') ++p;
    if (*p == '\n') ++p;
  }
  *used_len = p - str;

492
  if (key_vals.count("num_leaves") <= 0) {
493
    Log::Fatal("Tree model should contain num_leaves field");
Guolin Ke's avatar
Guolin Ke committed
494
495
496
497
  }

  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);

498
  if (key_vals.count("num_cat") <= 0) {
499
    Log::Fatal("Tree model should contain num_cat field");
500
501
502
503
  }

  Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);

504
  if (key_vals.count("leaf_value")) {
505
    leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], num_leaves_);
506
507
508
509
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

510
511
  if (num_leaves_ <= 1) { return; }

Guolin Ke's avatar
Guolin Ke committed
512
  if (key_vals.count("left_child")) {
513
    left_child_ = Common::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
514
515
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
516
517
  }

Guolin Ke's avatar
Guolin Ke committed
518
  if (key_vals.count("right_child")) {
519
    right_child_ = Common::StringToArrayFast<int>(key_vals["right_child"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
520
521
522
523
524
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (key_vals.count("split_feature")) {
525
    split_feature_ = Common::StringToArrayFast<int>(key_vals["split_feature"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
526
527
528
529
530
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (key_vals.count("threshold")) {
531
    threshold_ = Common::StringToArray<double>(key_vals["threshold"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
532
533
534
535
536
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  if (key_vals.count("split_gain")) {
537
    split_gain_ = Common::StringToArrayFast<float>(key_vals["split_gain"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
538
539
540
541
542
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_count")) {
543
    internal_count_ = Common::StringToArrayFast<int>(key_vals["internal_count"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
544
545
546
547
548
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_value")) {
549
    internal_value_ = Common::StringToArrayFast<double>(key_vals["internal_value"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
550
551
552
553
554
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("leaf_count")) {
555
    leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
Guolin Ke's avatar
Guolin Ke committed
556
557
558
559
560
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (key_vals.count("decision_type")) {
561
    decision_type_ = Common::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
Guolin Ke's avatar
Guolin Ke committed
562
563
564
565
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }

566
567
  if (num_cat_ > 0) {
    if (key_vals.count("cat_boundaries")) {
568
      cat_boundaries_ = Common::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
569
570
571
572
573
    } else {
      Log::Fatal("Tree model should contain cat_boundaries field.");
    }

    if (key_vals.count("cat_threshold")) {
574
      cat_threshold_ = Common::StringToArrayFast<uint32_t>(key_vals["cat_threshold"], cat_boundaries_.back());
575
    } else {
576
      Log::Fatal("Tree model should contain cat_threshold field");
577
578
579
    }
  }

Guolin Ke's avatar
Guolin Ke committed
580
581
582
583
584
  if (key_vals.count("shrinkage")) {
    Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }
585
  max_depth_ = -1;
Guolin Ke's avatar
Guolin Ke committed
586
587
}

Guolin Ke's avatar
Guolin Ke committed
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
// Appends an element for `feature_index` at position `unique_depth` of the
// path, with the given zero/one fractions. It then updates the pweight of
// elements 0..unique_depth for the longer path. This is the EXTEND step of
// TreeSHAP (Lundberg, Erion & Lee, 2018).
// The caller must provide room for unique_depth + 1 elements in unique_path.
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  unique_path[unique_depth].feature_index = feature_index;
  unique_path[unique_depth].zero_fraction = zero_fraction;
  unique_path[unique_depth].one_fraction = one_fraction;
  // the new element starts with weight 1 only if it is the first on the path
  unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
  // Iterate downward so unique_path[i].pweight still holds its previous value
  // when it is propagated into element i + 1, before being rescaled itself.
  for (int i = unique_depth - 1; i >= 0; i--) {
    unique_path[i + 1].pweight += one_fraction*unique_path[i].pweight*(i + 1)
      / static_cast<double>(unique_depth + 1);
    unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth - i)
      / static_cast<double>(unique_depth + 1);
  }
}

// Inverse of ExtendPath for the element at `path_index`. The path currently
// has unique_depth + 1 elements. This recomputes the pweights of elements
// 0..unique_depth-1 as if that element had never been added. It then shifts
// feature_index / zero_fraction / one_fraction of the later elements down by
// one; pweights are not shifted. This is the UNWIND step of TreeSHAP. Callers
// treat the path as one element shorter afterwards.
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;

  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      const double tmp = unique_path[i].pweight;
      unique_path[i].pweight = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth - i)
        / static_cast<double>(unique_depth + 1);
    } else {
      // with one_fraction == 0 only the zero_fraction term contributed, so undo that scaling directly
      unique_path[i].pweight = (unique_path[i].pweight*(unique_depth + 1))
        / static_cast<double>(zero_fraction*(unique_depth - i));
    }
  }

  // remove the element by shifting the following entries one slot left
  for (int i = path_index; i < unique_depth; ++i) {
    unique_path[i].feature_index = unique_path[i + 1].feature_index;
    unique_path[i].zero_fraction = unique_path[i + 1].zero_fraction;
    unique_path[i].one_fraction = unique_path[i + 1].one_fraction;
  }
}

// Returns the sum of the pweights the path would have after unwinding the
// element at `path_index`. It uses the same recurrence as UnwindPath but
// leaves unique_path unmodified (it is taken by const pointer).
double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;
  double total = 0;
  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      const double tmp = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      total += tmp;
      next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth - i)
                                                                     / static_cast<double>(unique_depth + 1));
    } else {
      total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
                                                           / static_cast<double>(unique_depth + 1));
    }
  }
  return total;
}

// recursive computation of SHAP values for a decision tree
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
653
  PathElement* unique_path = parent_unique_path + unique_depth;
Guolin Ke's avatar
Guolin Ke committed
654
655
656
657
658
659
660
661
662
663
664
665
666
667
  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

    // internal node
  } else {
668
    const int hot_index = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
669
670
671
672
673
674
675
676
677
678
679
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
680
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
Guolin Ke's avatar
Guolin Ke committed
681
682
683
684
685
686
687
688
689
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
690
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
691
692

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
693
             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
694
695
696
  }
}

697
698
699
700
701
double Tree::ExpectedValue() const {
  if (num_leaves_ == 1) return LeafOutput(0);
  const double total_count = internal_count_[0];
  double exp_value = 0.0;
  for (int i = 0; i < num_leaves(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
702
    exp_value += (leaf_count_[i] / total_count)*LeafOutput(i);
Guolin Ke's avatar
Guolin Ke committed
703
  }
704
  return exp_value;
Guolin Ke's avatar
Guolin Ke committed
705
706
}

707
708
709
710
711
712
713
714
715
716
717
void Tree::RecomputeMaxDepth() {
  if (num_leaves_ == 1) {
    max_depth_ = 0;
  } else {
    if (leaf_depth_.size() == 0) {
      RecomputeLeafDepths(0, 0);
    }
    max_depth_ = leaf_depth_[0];
    for (int i = 1; i < num_leaves(); ++i) {
      if (max_depth_ < leaf_depth_[i]) max_depth_ = leaf_depth_[i];
    }
Guolin Ke's avatar
Guolin Ke committed
718
719
720
  }
}

Guolin Ke's avatar
Guolin Ke committed
721
}  // namespace LightGBM