/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#include <LightGBM/tree.h>

#include <LightGBM/dataset.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>

#include <functional>
#include <iomanip>
#include <limits>
#include <sstream>

namespace LightGBM {

/*!
 * \brief Create an empty tree that can grow up to max_leaves leaves.
 *
 * Storage for the internal nodes (max_leaves - 1 of them) and for the leaves
 * (max_leaves) is allocated up front. The tree starts as a single root leaf:
 * leaf 0, at depth 0, with value 0, weight 0 and no parent (-1). There are no
 * categorical splits yet, so both categorical boundary arrays hold only their
 * leading 0 offset. max_depth_ is set to -1; RecomputeMaxDepth() assigns the
 * real value.
 * NOTE(review): max_leaves is assumed to be >= 1. A smaller value turns
 * resize(max_leaves_ - 1) into a request for an enormous size; confirm callers.
 */
Tree::Tree(int max_leaves)
  :max_leaves_(max_leaves) {
  // Per-internal-node arrays.
  left_child_.resize(max_leaves_ - 1);
  right_child_.resize(max_leaves_ - 1);
  split_feature_inner_.resize(max_leaves_ - 1);
  split_feature_.resize(max_leaves_ - 1);
  threshold_in_bin_.resize(max_leaves_ - 1);
  threshold_.resize(max_leaves_ - 1);
  decision_type_.resize(max_leaves_ - 1, 0);
  split_gain_.resize(max_leaves_ - 1);
  // Per-leaf arrays.
  leaf_parent_.resize(max_leaves_);
  leaf_value_.resize(max_leaves_);
  leaf_weight_.resize(max_leaves_);
  leaf_count_.resize(max_leaves_);
  // Per-internal-node statistics.
  internal_value_.resize(max_leaves_ - 1);
  internal_weight_.resize(max_leaves_ - 1);
  internal_count_.resize(max_leaves_ - 1);
  leaf_depth_.resize(max_leaves_);
  // root is in the depth 0
  leaf_depth_[0] = 0;
  num_leaves_ = 1;
  leaf_value_[0] = 0.0f;
  leaf_weight_[0] = 0.0f;
  leaf_parent_[0] = -1;
  shrinkage_ = 1.0f;
  num_cat_ = 0;
  // cat_boundaries_[k] .. cat_boundaries_[k + 1] delimit the bitset words of
  // the k-th categorical split, so both arrays start with a single 0 offset.
  cat_boundaries_.push_back(0);
  cat_boundaries_inner_.push_back(0);
  max_depth_ = -1;
}
// The destructor has nothing extra to release: the original body was empty,
// so each member is simply destroyed by its own destructor.
Tree::~Tree() = default;

int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
                double threshold_double, double left_value, double right_value,
53
54
                int left_cnt, int right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type, bool default_left) {
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
Guolin Ke's avatar
Guolin Ke committed
55
  int new_node_idx = num_leaves_ - 1;
Guolin Ke's avatar
Guolin Ke committed
56
  decision_type_[new_node_idx] = 0;
57
  SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
62
63
64
  SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
65
  }
Guolin Ke's avatar
Guolin Ke committed
66
  threshold_in_bin_[new_node_idx] = threshold_bin;
67
  threshold_[new_node_idx] = Common::AvoidInf(threshold_double);
68
69
70
  ++num_leaves_;
  return num_leaves_ - 1;
}
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                           const uint32_t* threshold, int num_threshold, double left_value, double right_value,
74
75
                           data_size_t left_cnt, data_size_t right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type) {
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
76
77
78
79
80
81
82
83
84
85
  int new_node_idx = num_leaves_ - 1;
  decision_type_[new_node_idx] = 0;
  SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
  }
86
87
  threshold_in_bin_[new_node_idx] = num_cat_;
  threshold_[new_node_idx] = num_cat_;
88
  ++num_cat_;
89
90
91
92
93
94
95
96
  cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
  for (int i = 0; i < num_threshold; ++i) {
    cat_threshold_.push_back(threshold[i]);
  }
  cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
  for (int i = 0; i < num_threshold_bin; ++i) {
    cat_threshold_inner_.push_back(threshold_bin[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
97
98
99
100
  ++num_leaves_;
  return num_leaves_ - 1;
}

// PredictionFun expands to the per-thread body used by AddPredictionToScore.
// It creates `niter` bin iterators, where iterator i comes from
// data->FeatureIterator(fidx_in_iter) and is Reset() to start_pos. Then, for
// every i in [start, end), it walks from node 0 with
// decision_fun(bin, node, default_bins[node], max_bins[node]) until `node`
// becomes negative, and adds leaf_value_[~node] to score[data_idx].
//   niter        - number of bin iterators to create
//   fidx_in_iter - inner feature index for iterator `i` (may use the loop's `i`)
//   start_pos    - row position every iterator is Reset() to
//   decision_fun - member that maps (bin, node, default bin, max bin) to the next node
//   iter_idx     - which iterator to read for the current `node`
//   data_idx     - row to read and to update (may use the row loop's `i`)
// The expansion relies on `data`, `start`, `end`, `score`, `default_bins`,
// `max_bins` and `leaf_value_` being in scope at the use site.
// Comments cannot be placed inside the body because every line ends with a
// line-continuation backslash.
#define PredictionFun(niter, fidx_in_iter, start_pos, decision_fun, iter_idx, data_idx) \
std::vector<std::unique_ptr<BinIterator>> iter((niter)); \
for (int i = 0; i < (niter); ++i) { \
  iter[i].reset(data->FeatureIterator((fidx_in_iter))); \
  iter[i]->Reset((start_pos)); \
}\
for (data_size_t i = start; i < end; ++i) {\
  int node = 0;\
  while (node >= 0) {\
    node = decision_fun(iter[(iter_idx)]->Get((data_idx)), node, default_bins[node], max_bins[node]);\
  }\
  score[(data_idx)] += static_cast<double>(leaf_value_[~node]);\
}\

void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
116
117
118
119
120
121
122
123
124
  if (num_leaves_ <= 1) {
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[i] += leaf_value_[0];
      }
    }
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
125
126
127
128
129
130
131
132
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
133
  if (num_cat_ > 0) {
134
    if (data->num_features() > num_leaves_ - 1) {
135
136
137
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
138
139
      });
    } else {
140
141
142
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
143
144
      });
    }
Guolin Ke's avatar
Guolin Ke committed
145
  } else {
146
    if (data->num_features() > num_leaves_ - 1) {
147
148
149
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
150
151
      });
    } else {
152
153
154
      Threading::For<data_size_t>(0, num_data, [this, &data, score, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
155
156
      });
    }
Guolin Ke's avatar
Guolin Ke committed
157
  }
Guolin Ke's avatar
Guolin Ke committed
158
159
}

void Tree::AddPredictionToScore(const Dataset* data,
Guolin Ke's avatar
Guolin Ke committed
161
162
163
                                const data_size_t* used_data_indices,
                                data_size_t num_data, double* score) const {
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
164
165
166
167
168
169
    if (leaf_value_[0] != 0.0f) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        score[used_data_indices[i]] += leaf_value_[0];
      }
    }
Guolin Ke's avatar
Guolin Ke committed
170
    return;
Guolin Ke's avatar
Guolin Ke committed
171
  }
Guolin Ke's avatar
Guolin Ke committed
172
173
174
175
176
177
178
179
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
180
  if (num_cat_ > 0) {
181
    if (data->num_features() > num_leaves_ - 1) {
182
183
184
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
185
186
      });
    } else {
187
188
189
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
190
191
      });
    }
Guolin Ke's avatar
Guolin Ke committed
192
  } else {
193
    if (data->num_features() > num_leaves_ - 1) {
194
195
196
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
197
198
      });
    } else {
199
200
201
      Threading::For<data_size_t>(0, num_data, [this, &data, score, used_data_indices, &default_bins, &max_bins]
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
202
203
      });
    }
Guolin Ke's avatar
Guolin Ke committed
204
  }
Guolin Ke's avatar
Guolin Ke committed
205
206
}

#undef PredictionFun

std::string Tree::ToString() const {
210
  std::stringstream str_buf;
211
212
  str_buf << "num_leaves=" << num_leaves_ << '\n';
  str_buf << "num_cat=" << num_cat_ << '\n';
213
  str_buf << "split_feature="
214
    << Common::ArrayToStringFast(split_feature_, num_leaves_ - 1) << '\n';
215
  str_buf << "split_gain="
216
    << Common::ArrayToStringFast(split_gain_, num_leaves_ - 1) << '\n';
217
  str_buf << "threshold="
218
    << Common::ArrayToString(threshold_, num_leaves_ - 1) << '\n';
219
  str_buf << "decision_type="
220
    << Common::ArrayToStringFast(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1) << '\n';
221
  str_buf << "left_child="
222
    << Common::ArrayToStringFast(left_child_, num_leaves_ - 1) << '\n';
223
  str_buf << "right_child="
224
    << Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
225
  str_buf << "leaf_value="
226
    << Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
227
228
  str_buf << "leaf_weight="
    << Common::ArrayToString(leaf_weight_, num_leaves_) << '\n';
229
  str_buf << "leaf_count="
230
    << Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
231
  str_buf << "internal_value="
232
    << Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
233
234
  str_buf << "internal_weight="
    << Common::ArrayToStringFast(internal_weight_, num_leaves_ - 1) << '\n';
235
  str_buf << "internal_count="
236
    << Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
237
238
  if (num_cat_ > 0) {
    str_buf << "cat_boundaries="
239
      << Common::ArrayToStringFast(cat_boundaries_, num_cat_ + 1) << '\n';
240
    str_buf << "cat_threshold="
241
      << Common::ArrayToStringFast(cat_threshold_, cat_threshold_.size()) << '\n';
242
  }
243
244
  str_buf << "shrinkage=" << shrinkage_ << '\n';
  str_buf << '\n';
245
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
246
247
}

std::string Tree::ToJSON() const {
249
  std::stringstream str_buf;
250
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
251
252
253
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
  str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
wxchan's avatar
wxchan committed
254
  if (num_leaves_ == 1) {
255
    str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
wxchan's avatar
wxchan committed
256
  } else {
257
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
wxchan's avatar
wxchan committed
258
  }
wxchan's avatar
wxchan committed
259

260
  return str_buf.str();
wxchan's avatar
wxchan committed
261
262
}

std::string Tree::NodeToJSON(int index) const {
264
  std::stringstream str_buf;
265
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
266
267
  if (index >= 0) {
    // non-leaf
268
269
270
271
    str_buf << "{" << '\n';
    str_buf << "\"split_index\":" << index << "," << '\n';
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
    str_buf << "\"split_gain\":" << split_gain_[index] << "," << '\n';
272
    if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
273
274
275
276
277
278
279
280
281
282
283
      int cat_idx = static_cast<int>(threshold_[index]);
      std::vector<int> cats;
      for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
        for (int j = 0; j < 32; ++j) {
          int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
          if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
            cats.push_back(cat);
          }
        }
      }
284
285
      str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << '\n';
      str_buf << "\"decision_type\":\"==\"," << '\n';
286
    } else {
287
288
      str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << '\n';
      str_buf << "\"decision_type\":\"<=\"," << '\n';
289
290
    }
    if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
291
      str_buf << "\"default_left\":true," << '\n';
292
    } else {
293
      str_buf << "\"default_left\":false," << '\n';
294
295
296
    }
    uint8_t missing_type = GetMissingType(decision_type_[index]);
    if (missing_type == 0) {
297
      str_buf << "\"missing_type\":\"None\"," << '\n';
298
    } else if (missing_type == 1) {
299
      str_buf << "\"missing_type\":\"Zero\"," << '\n';
300
    } else {
301
      str_buf << "\"missing_type\":\"NaN\"," << '\n';
302
    }
303
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
304
    str_buf << "\"internal_weight\":" << internal_weight_[index] << "," << '\n';
305
306
307
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
308
    str_buf << "}";
wxchan's avatar
wxchan committed
309
310
311
  } else {
    // leaf
    index = ~index;
312
313
314
    str_buf << "{" << '\n';
    str_buf << "\"leaf_index\":" << index << "," << '\n';
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
315
    str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n';
316
    str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
317
    str_buf << "}";
wxchan's avatar
wxchan committed
318
319
  }

320
  return str_buf.str();
wxchan's avatar
wxchan committed
321
322
}

std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
Guolin Ke's avatar
Guolin Ke committed
327
  if (missing_type == 0 || (missing_type == 1 && default_left && kZeroThreshold < threshold_[node])) {
Guolin Ke's avatar
Guolin Ke committed
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
    str_buf << "if (fval <= " << threshold_[node] << ") {";
  } else if (missing_type == 1) {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
  if (missing_type == 2) {
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
353
  int cat_idx = static_cast<int>(threshold_[node]);
354
355
356
357
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
Guolin Ke's avatar
Guolin Ke committed
358
359
360
  return str_buf.str();
}

std::string Tree::ToIfElse(int index, bool predict_leaf_index) const {
362
363
  std::stringstream str_buf;
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
364
  if (predict_leaf_index) {
365
366
367
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
368
369
  if (num_leaves_ <= 1) {
    str_buf << "return " << leaf_value_[0] << ";";
370
  } else {
371
372
373
374
375
376
377
378
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
379
380
381
382
383
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
384
    str_buf << NodeToIfElse(0, predict_leaf_index);
385
  }
386
  str_buf << " }" << '\n';
387

388
  // Predict func by Map to ifelse
389
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
390
  if (predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
391
392
393
    str_buf << "LeafByMap";
  } else {
    str_buf << "ByMap";
394
395
396
  }
  str_buf << "(const std::unordered_map<int, double>& arr) { ";
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
    str_buf << "return " << leaf_value_[0] << ";";
  } else {
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
412
    str_buf << NodeToIfElseByMap(0, predict_leaf_index);
413
  }
414
  str_buf << " }" << '\n';
415

416
417
418
  return str_buf.str();
}

std::string Tree::NodeToIfElse(int index, bool predict_leaf_index) const {
420
421
422
423
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
424
    str_buf << "fval = arr[" << split_feature_[index] << "];";
Guolin Ke's avatar
Guolin Ke committed
425
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
426
      str_buf << NumericalDecisionIfElse(index);
427
    } else {
428
      str_buf << CategoricalDecisionIfElse(index);
429
430
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
431
    str_buf << NodeToIfElse(left_child_[index], predict_leaf_index);
432
433
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
434
    str_buf << NodeToIfElse(right_child_[index], predict_leaf_index);
435
436
437
438
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
439
    if (predict_leaf_index) {
440
441
442
443
444
445
446
447
448
449
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
451
452
453
454
455
456
457
458
459
460
461
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
    str_buf << "fval = arr.count(" << split_feature_[index] << ") > 0 ? arr.at(" << split_feature_[index] << ") : 0.0f;";
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
      str_buf << NumericalDecisionIfElse(index);
    } else {
      str_buf << CategoricalDecisionIfElse(index);
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
462
    str_buf << NodeToIfElseByMap(left_child_[index], predict_leaf_index);
463
464
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
465
    str_buf << NodeToIfElseByMap(right_child_[index], predict_leaf_index);
466
467
468
469
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
470
    if (predict_leaf_index) {
471
472
473
474
475
476
477
478
479
480
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

/*!
 * \brief Parse a tree from the "key=value" text format written by ToString().
 *
 * Reads up to 17 consecutive lines starting at `str`, stopping early at an
 * empty line or a NUL character. *used_len receives the number of characters
 * consumed.
 *  - num_leaves, num_cat and leaf_value are mandatory.
 *  - left_child, right_child, split_feature and threshold are mandatory only
 *    for trees with more than one leaf.
 *  - The remaining per-node and per-leaf arrays fall back to zero-filled
 *    arrays.
 *  - shrinkage falls back to 1.
 *  - cat_boundaries and cat_threshold are required when num_cat > 0.
 * Format errors are reported through Log::Fatal.
 */
Tree::Tree(const char* str, size_t* used_len) {
  auto p = str;
  std::unordered_map<std::string, std::string> key_vals;
  const int max_num_line = 17;
  int read_line = 0;
  while (read_line < max_num_line) {
    // An empty line (or the end of the string) terminates the tree block.
    if (*p == '\r' || *p == '\n' || *p == '\0') break;
    auto start = p;
    // Never scan past the terminating NUL: a truncated model string used to
    // make these loops read out of bounds.
    while (*p != '=' && *p != '\0') ++p;
    if (*p == '\0') {
      Log::Fatal("Tree model string format error, expect key=value lines");
    }
    std::string key(start, p - start);
    ++p;
    start = p;
    while (*p != '\r' && *p != '\n' && *p != '\0') ++p;
    key_vals[key] = std::string(start, p - start);
    ++read_line;
    if (*p == '\r') ++p;
    if (*p == '\n') ++p;
  }
  *used_len = p - str;

  if (key_vals.count("num_leaves") == 0) {
    Log::Fatal("Tree model should contain num_leaves field");
  }

  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);

  if (key_vals.count("num_cat") == 0) {
    Log::Fatal("Tree model should contain num_cat field");
  }

  Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);

  if (key_vals.count("leaf_value")) {
    leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], num_leaves_);
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

  if (key_vals.count("shrinkage")) {
    Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }

  // Set before the single-leaf early return below, which previously skipped
  // this assignment and left max_depth_ indeterminate.
  max_depth_ = -1;

  if (num_leaves_ <= 1) { return; }

  if (key_vals.count("left_child")) {
    left_child_ = Common::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
  }

  if (key_vals.count("right_child")) {
    right_child_ = Common::StringToArrayFast<int>(key_vals["right_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (key_vals.count("split_feature")) {
    split_feature_ = Common::StringToArrayFast<int>(key_vals["split_feature"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (key_vals.count("threshold")) {
    threshold_ = Common::StringToArray<double>(key_vals["threshold"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  if (key_vals.count("split_gain")) {
    split_gain_ = Common::StringToArrayFast<float>(key_vals["split_gain"], num_leaves_ - 1);
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_count")) {
    internal_count_ = Common::StringToArrayFast<int>(key_vals["internal_count"], num_leaves_ - 1);
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_value")) {
    internal_value_ = Common::StringToArrayFast<double>(key_vals["internal_value"], num_leaves_ - 1);
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_weight")) {
    internal_weight_ = Common::StringToArrayFast<double>(key_vals["internal_weight"], num_leaves_ - 1);
  } else {
    internal_weight_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("leaf_weight")) {
    leaf_weight_ = Common::StringToArrayFast<double>(key_vals["leaf_weight"], num_leaves_);
  } else {
    leaf_weight_.resize(num_leaves_);
  }

  if (key_vals.count("leaf_count")) {
    leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (key_vals.count("decision_type")) {
    decision_type_ = Common::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }

  if (num_cat_ > 0) {
    if (key_vals.count("cat_boundaries")) {
      cat_boundaries_ = Common::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
    } else {
      Log::Fatal("Tree model should contain cat_boundaries field.");
    }

    if (key_vals.count("cat_threshold")) {
      cat_threshold_ = Common::StringToArrayFast<uint32_t>(key_vals["cat_threshold"], cat_boundaries_.back());
    } else {
      Log::Fatal("Tree model should contain cat_threshold field");
    }
  }
}

// Path-extension step used by TreeSHAP below.
// Writes a new element (feature_index, zero_fraction, one_fraction) at position
// unique_depth of unique_path, then updates the pweight values of elements
// 0..unique_depth in place to account for it.
//   unique_path   - path buffer; must have room for unique_depth + 1 elements
//   unique_depth  - position at which the new element is written
//   zero_fraction - stored as the new element's zero_fraction and used to
//                   rescale the existing weights
//   one_fraction  - stored as the new element's one_fraction and used to shift
//                   weight up by one position
//   feature_index - feature associated with the new element
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  unique_path[unique_depth].feature_index = feature_index;
  unique_path[unique_depth].zero_fraction = zero_fraction;
  unique_path[unique_depth].one_fraction = one_fraction;
  // The first element (depth 0) starts with weight 1; later ones start at 0
  // and receive weight from their predecessors in the loop below.
  unique_path[unique_depth].pweight = (unique_depth == 0 ? 1 : 0);
  // Walk backwards so that, in each iteration, unique_path[i].pweight is read
  // for the i + 1 update before it is rescaled.
  for (int i = unique_depth - 1; i >= 0; i--) {
    unique_path[i + 1].pweight += one_fraction*unique_path[i].pweight*(i + 1)
      / static_cast<double>(unique_depth + 1);
    unique_path[i].pweight = zero_fraction*unique_path[i].pweight*(unique_depth - i)
      / static_cast<double>(unique_depth + 1);
  }
}

// Remove the element at path_index from a unique path whose last element is at
// unique_depth.
// First, the pweight values of elements 0..unique_depth-1 are recomputed by
// solving ExtendPath's update backwards, using the removed element's zero/one
// fractions. Then the (feature_index, zero_fraction, one_fraction) triples
// after path_index are shifted down by one position. pweight values are not
// shifted. The caller is expected to shrink its depth afterwards (TreeSHAP does
// `unique_depth -= 1`).
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  // Weight carried down from the top of the path while reversing the update.
  double next_one_portion = unique_path[unique_depth].pweight;

  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      // Solve ExtendPath's "pweight[i + 1] += one_fraction * pweight[i] * ..."
      // for the old pweight[i], then strip its rescaled part from the current
      // pweight[i] to get the portion carried to the next (lower) index.
      const double tmp = unique_path[i].pweight;
      unique_path[i].pweight = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth - i)
        / static_cast<double>(unique_depth + 1);
    } else {
      // With one_fraction == 0 ExtendPath only scaled pweight[i] by
      // zero_fraction * (unique_depth - i) / (unique_depth + 1); divide it back out.
      unique_path[i].pweight = (unique_path[i].pweight*(unique_depth + 1))
        / static_cast<double>(zero_fraction*(unique_depth - i));
    }
  }

  for (int i = path_index; i < unique_depth; ++i) {
    unique_path[i].feature_index = unique_path[i + 1].feature_index;
    unique_path[i].zero_fraction = unique_path[i + 1].zero_fraction;
    unique_path[i].one_fraction = unique_path[i + 1].one_fraction;
  }
}

// Sum of the pweight values elements 0..unique_depth-1 would have after
// UnwindPath(unique_path, unique_depth, path_index).
// It uses the same recurrence as UnwindPath, but accumulates into `total`
// instead of writing back, so unique_path is left unchanged.
// TreeSHAP multiplies the result by (one_fraction - zero_fraction) * leaf value
// to obtain the element's contribution at a leaf.
double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;
  double total = 0;
  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      // tmp is the unwound pweight[i] (same formula as UnwindPath).
      const double tmp = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      total += tmp;
      next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth - i)
                                                                     / static_cast<double>(unique_depth + 1));
    } else {
      total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
                                                           / static_cast<double>(unique_depth + 1));
    }
  }
  return total;
}

// recursive computation of SHAP values for a decision tree
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
676
  PathElement* unique_path = parent_unique_path + unique_depth;
Guolin Ke's avatar
Guolin Ke committed
677
678
679
680
681
682
683
684
685
686
687
688
689
690
  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

    // internal node
  } else {
691
    const int hot_index = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
692
693
694
695
696
697
698
699
700
701
702
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
703
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
Guolin Ke's avatar
Guolin Ke committed
704
705
706
707
708
709
710
711
712
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
713
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
714
715

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
716
             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
717
718
719
  }
}

double Tree::ExpectedValue() const {
  if (num_leaves_ == 1) return LeafOutput(0);
  const double total_count = internal_count_[0];
  double exp_value = 0.0;
  for (int i = 0; i < num_leaves(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
725
    exp_value += (leaf_count_[i] / total_count)*LeafOutput(i);
Guolin Ke's avatar
Guolin Ke committed
726
  }
727
  return exp_value;
Guolin Ke's avatar
Guolin Ke committed
728
729
}

/*!
 * \brief Set max_depth_ to the largest leaf depth in the tree.
 *
 * A single-leaf tree has depth 0. Leaf depths are rebuilt from the root only
 * when leaf_depth_ is empty; the string constructor, for example, never fills
 * it. Otherwise the stored depths are used as-is.
 */
void Tree::RecomputeMaxDepth() {
  if (num_leaves_ == 1) {
    max_depth_ = 0;
  } else {
    if (leaf_depth_.empty()) {
      RecomputeLeafDepths(0, 0);
    }
    max_depth_ = leaf_depth_[0];
    for (int i = 1; i < num_leaves(); ++i) {
      if (max_depth_ < leaf_depth_[i]) max_depth_ = leaf_depth_[i];
    }
  }
}

}  // namespace LightGBM