tree.cpp 29 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
#include <LightGBM/tree.h>

#include <LightGBM/dataset.h>
8
9
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/threading.h>
Guolin Ke's avatar
Guolin Ke committed
10
11

#include <functional>
12
#include <iomanip>
13
#include <sstream>
Guolin Ke's avatar
Guolin Ke committed
14
15
16
17
18

namespace LightGBM {

// Builds an empty tree (a single root leaf) with storage sized for up to
// max_leaves leaves. Arrays indexed by internal node hold max_leaves - 1
// entries; arrays indexed by leaf hold max_leaves entries.
Tree::Tree(int max_leaves)
  :max_leaves_(max_leaves) {
  const int max_internal_nodes = max_leaves_ - 1;

  // storage indexed by internal (split) node
  left_child_.resize(max_internal_nodes);
  right_child_.resize(max_internal_nodes);
  split_feature_inner_.resize(max_internal_nodes);
  split_feature_.resize(max_internal_nodes);
  threshold_in_bin_.resize(max_internal_nodes);
  threshold_.resize(max_internal_nodes);
  decision_type_.resize(max_internal_nodes, 0);
  split_gain_.resize(max_internal_nodes);
  internal_value_.resize(max_internal_nodes);
  internal_weight_.resize(max_internal_nodes);
  internal_count_.resize(max_internal_nodes);

  // storage indexed by leaf
  leaf_parent_.resize(max_leaves_);
  leaf_value_.resize(max_leaves_);
  leaf_weight_.resize(max_leaves_);
  leaf_count_.resize(max_leaves_);
  leaf_depth_.resize(max_leaves_);

  // the tree starts as one leaf (the root): depth 0, no parent, zero output
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_parent_[0] = -1;
  leaf_value_[0] = 0.0f;
  leaf_weight_[0] = 0.0f;

  shrinkage_ = 1.0f;
  max_depth_ = -1;

  // no categorical splits yet; both boundary arrays start with the offset 0
  num_cat_ = 0;
  cat_boundaries_.push_back(0);
  cat_boundaries_inner_.push_back(0);
}
Guolin Ke's avatar
Guolin Ke committed
47

Guolin Ke's avatar
Guolin Ke committed
48
49
50
// Nothing to release explicitly; members clean up after themselves.
Tree::~Tree() = default;

51
52
int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
                double threshold_double, double left_value, double right_value,
53
54
                int left_cnt, int right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type, bool default_left) {
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
Guolin Ke's avatar
Guolin Ke committed
55
  int new_node_idx = num_leaves_ - 1;
Guolin Ke's avatar
Guolin Ke committed
56
  decision_type_[new_node_idx] = 0;
57
  SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
62
63
64
  SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
65
  }
Guolin Ke's avatar
Guolin Ke committed
66
  threshold_in_bin_[new_node_idx] = threshold_bin;
Guolin Ke's avatar
Guolin Ke committed
67
  threshold_[new_node_idx] = threshold_double;
68
69
70
  ++num_leaves_;
  return num_leaves_ - 1;
}
Guolin Ke's avatar
Guolin Ke committed
71

72
73
int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32_t* threshold_bin, int num_threshold_bin,
                           const uint32_t* threshold, int num_threshold, double left_value, double right_value,
74
75
                           data_size_t left_cnt, data_size_t right_cnt, double left_weight, double right_weight, float gain, MissingType missing_type) {
  Split(leaf, feature, real_feature, left_value, right_value, left_cnt, right_cnt, left_weight, right_weight, gain);
76
77
78
79
80
81
82
83
84
85
  int new_node_idx = num_leaves_ - 1;
  decision_type_[new_node_idx] = 0;
  SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
  if (missing_type == MissingType::None) {
    SetMissingType(&decision_type_[new_node_idx], 0);
  } else if (missing_type == MissingType::Zero) {
    SetMissingType(&decision_type_[new_node_idx], 1);
  } else if (missing_type == MissingType::NaN) {
    SetMissingType(&decision_type_[new_node_idx], 2);
  }
86
87
  threshold_in_bin_[new_node_idx] = num_cat_;
  threshold_[new_node_idx] = num_cat_;
88
  ++num_cat_;
89
90
91
92
93
94
95
96
  cat_boundaries_.push_back(cat_boundaries_.back() + num_threshold);
  for (int i = 0; i < num_threshold; ++i) {
    cat_threshold_.push_back(threshold[i]);
  }
  cat_boundaries_inner_.push_back(cat_boundaries_inner_.back() + num_threshold_bin);
  for (int i = 0; i < num_threshold_bin; ++i) {
    cat_threshold_inner_.push_back(threshold_bin[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
97
98
99
100
  ++num_leaves_;
  return num_leaves_ - 1;
}

Guolin Ke's avatar
Guolin Ke committed
101
102
103
104
105
106
107
108
109
110
111
112
113
114
// Body of a per-range prediction worker.
// Creates `niter` bin iterators (iterator k reads feature `fidx_in_iter`,
// positioned at `start_pos`), then for every row i in [start, end) walks the
// tree from node 0 with `decision_fun` until a negative node is reached and
// adds leaf_value_[~node] to score[data_idx].
// The expansion uses these names from the surrounding scope: data, start,
// end, score, default_bins, max_bins. Arguments may refer to the loop
// variable `i` and to the current `node`.
#define PredictionFun(niter, fidx_in_iter, start_pos, decision_fun, iter_idx, \
                      data_idx)                                               \
  std::vector<std::unique_ptr<BinIterator>> bin_iters((niter));               \
  for (int i = 0; i < (niter); ++i) {                                         \
    bin_iters[i].reset(data->FeatureIterator((fidx_in_iter)));                \
    bin_iters[i]->Reset((start_pos));                                         \
  }                                                                           \
  for (data_size_t i = start; i < end; ++i) {                                 \
    int node = 0;                                                             \
    while (node >= 0) {                                                       \
      node = decision_fun(bin_iters[(iter_idx)]->Get((data_idx)), node,       \
                          default_bins[node], max_bins[node]);                \
    }                                                                         \
    score[(data_idx)] += static_cast<double>(leaf_value_[~node]);             \
  }

117
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
118
119
  if (num_leaves_ <= 1) {
    if (leaf_value_[0] != 0.0f) {
Guolin Ke's avatar
Guolin Ke committed
120
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
Guolin Ke's avatar
Guolin Ke committed
121
122
123
124
125
126
      for (data_size_t i = 0; i < num_data; ++i) {
        score[i] += leaf_value_[0];
      }
    }
    return;
  }
Guolin Ke's avatar
Guolin Ke committed
127
128
129
130
131
132
133
134
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
135
  if (num_cat_ > 0) {
136
    if (data->num_features() > num_leaves_ - 1) {
Guolin Ke's avatar
Guolin Ke committed
137
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
138
139
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, DecisionInner, node, i);
140
141
      });
    } else {
Guolin Ke's avatar
Guolin Ke committed
142
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
143
144
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, DecisionInner, split_feature_inner_[node], i);
145
146
      });
    }
Guolin Ke's avatar
Guolin Ke committed
147
  } else {
148
    if (data->num_features() > num_leaves_ - 1) {
Guolin Ke's avatar
Guolin Ke committed
149
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
150
151
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], start, NumericalDecisionInner, node, i);
152
153
      });
    } else {
Guolin Ke's avatar
Guolin Ke committed
154
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, &default_bins, &max_bins]
155
156
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, start, NumericalDecisionInner, split_feature_inner_[node], i);
157
158
      });
    }
Guolin Ke's avatar
Guolin Ke committed
159
  }
Guolin Ke's avatar
Guolin Ke committed
160
161
}

Guolin Ke's avatar
Guolin Ke committed
162
void Tree::AddPredictionToScore(const Dataset* data,
Guolin Ke's avatar
Guolin Ke committed
163
164
165
                                const data_size_t* used_data_indices,
                                data_size_t num_data, double* score) const {
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
166
    if (leaf_value_[0] != 0.0f) {
Guolin Ke's avatar
Guolin Ke committed
167
#pragma omp parallel for schedule(static, 512) if (num_data >= 1024)
Guolin Ke's avatar
Guolin Ke committed
168
169
170
171
      for (data_size_t i = 0; i < num_data; ++i) {
        score[used_data_indices[i]] += leaf_value_[0];
      }
    }
Guolin Ke's avatar
Guolin Ke committed
172
    return;
Guolin Ke's avatar
Guolin Ke committed
173
  }
Guolin Ke's avatar
Guolin Ke committed
174
175
176
177
178
179
180
181
  std::vector<uint32_t> default_bins(num_leaves_ - 1);
  std::vector<uint32_t> max_bins(num_leaves_ - 1);
  for (int i = 0; i < num_leaves_ - 1; ++i) {
    const int fidx = split_feature_inner_[i];
    auto bin_mapper = data->FeatureBinMapper(fidx);
    default_bins[i] = bin_mapper->GetDefaultBin();
    max_bins[i] = bin_mapper->num_bin() - 1;
  }
182
  if (num_cat_ > 0) {
183
    if (data->num_features() > num_leaves_ - 1) {
Guolin Ke's avatar
Guolin Ke committed
184
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
185
186
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], DecisionInner, node, used_data_indices[i]);
187
188
      });
    } else {
Guolin Ke's avatar
Guolin Ke committed
189
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
190
191
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], DecisionInner, split_feature_inner_[node], used_data_indices[i]);
192
193
      });
    }
Guolin Ke's avatar
Guolin Ke committed
194
  } else {
195
    if (data->num_features() > num_leaves_ - 1) {
Guolin Ke's avatar
Guolin Ke committed
196
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
197
198
      (int, data_size_t start, data_size_t end) {
        PredictionFun(num_leaves_ - 1, split_feature_inner_[i], used_data_indices[start], NumericalDecisionInner, node, used_data_indices[i]);
199
200
      });
    } else {
Guolin Ke's avatar
Guolin Ke committed
201
      Threading::For<data_size_t>(0, num_data, 512, [this, &data, score, used_data_indices, &default_bins, &max_bins]
202
203
      (int, data_size_t start, data_size_t end) {
        PredictionFun(data->num_features(), i, used_data_indices[start], NumericalDecisionInner, split_feature_inner_[node], used_data_indices[i]);
204
205
      });
    }
Guolin Ke's avatar
Guolin Ke committed
206
  }
Guolin Ke's avatar
Guolin Ke committed
207
208
}

209
210
#undef PredictionFun

211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
// Returns the largest value among the first num_leaves_ entries of
// leaf_value_ (leaf 0 is always considered).
double Tree::GetUpperBoundValue() const {
  double largest = leaf_value_[0];
  for (int leaf = 1; leaf < num_leaves_; ++leaf) {
    const double candidate = leaf_value_[leaf];
    largest = (candidate > largest) ? candidate : largest;
  }
  return largest;
}

// Returns the smallest value among the first num_leaves_ entries of
// leaf_value_ (leaf 0 is always considered).
double Tree::GetLowerBoundValue() const {
  double smallest = leaf_value_[0];
  for (int leaf = 1; leaf < num_leaves_; ++leaf) {
    const double candidate = leaf_value_[leaf];
    smallest = (candidate < smallest) ? candidate : smallest;
  }
  return smallest;
}

Guolin Ke's avatar
Guolin Ke committed
231
std::string Tree::ToString() const {
232
  std::stringstream str_buf;
233
234
  str_buf << "num_leaves=" << num_leaves_ << '\n';
  str_buf << "num_cat=" << num_cat_ << '\n';
235
  str_buf << "split_feature="
236
    << Common::ArrayToStringFast(split_feature_, num_leaves_ - 1) << '\n';
237
  str_buf << "split_gain="
238
    << Common::ArrayToStringFast(split_gain_, num_leaves_ - 1) << '\n';
239
  str_buf << "threshold="
240
    << Common::ArrayToString(threshold_, num_leaves_ - 1) << '\n';
241
  str_buf << "decision_type="
242
    << Common::ArrayToStringFast(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1) << '\n';
243
  str_buf << "left_child="
244
    << Common::ArrayToStringFast(left_child_, num_leaves_ - 1) << '\n';
245
  str_buf << "right_child="
246
    << Common::ArrayToStringFast(right_child_, num_leaves_ - 1) << '\n';
247
  str_buf << "leaf_value="
248
    << Common::ArrayToString(leaf_value_, num_leaves_) << '\n';
249
250
  str_buf << "leaf_weight="
    << Common::ArrayToString(leaf_weight_, num_leaves_) << '\n';
251
  str_buf << "leaf_count="
252
    << Common::ArrayToStringFast(leaf_count_, num_leaves_) << '\n';
253
  str_buf << "internal_value="
254
    << Common::ArrayToStringFast(internal_value_, num_leaves_ - 1) << '\n';
255
256
  str_buf << "internal_weight="
    << Common::ArrayToStringFast(internal_weight_, num_leaves_ - 1) << '\n';
257
  str_buf << "internal_count="
258
    << Common::ArrayToStringFast(internal_count_, num_leaves_ - 1) << '\n';
259
260
  if (num_cat_ > 0) {
    str_buf << "cat_boundaries="
261
      << Common::ArrayToStringFast(cat_boundaries_, num_cat_ + 1) << '\n';
262
    str_buf << "cat_threshold="
263
      << Common::ArrayToStringFast(cat_threshold_, cat_threshold_.size()) << '\n';
264
  }
265
266
  str_buf << "shrinkage=" << shrinkage_ << '\n';
  str_buf << '\n';
267
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
268
269
}

Guolin Ke's avatar
Guolin Ke committed
270
std::string Tree::ToJSON() const {
271
  std::stringstream str_buf;
272
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
273
274
275
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << '\n';
  str_buf << "\"num_cat\":" << num_cat_ << "," << '\n';
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << '\n';
wxchan's avatar
wxchan committed
276
  if (num_leaves_ == 1) {
277
    str_buf << "\"tree_structure\":{" << "\"leaf_value\":" << leaf_value_[0] << "}" << '\n';
wxchan's avatar
wxchan committed
278
  } else {
279
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << '\n';
wxchan's avatar
wxchan committed
280
  }
wxchan's avatar
wxchan committed
281

282
  return str_buf.str();
wxchan's avatar
wxchan committed
283
284
}

Guolin Ke's avatar
Guolin Ke committed
285
std::string Tree::NodeToJSON(int index) const {
286
  std::stringstream str_buf;
287
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
288
289
  if (index >= 0) {
    // non-leaf
290
291
292
    str_buf << "{" << '\n';
    str_buf << "\"split_index\":" << index << "," << '\n';
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
Guolin Ke's avatar
Guolin Ke committed
293
    str_buf << "\"split_gain\":" << Common::AvoidInf(split_gain_[index]) << "," << '\n';
294
    if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
295
296
297
298
299
300
301
302
303
304
305
      int cat_idx = static_cast<int>(threshold_[index]);
      std::vector<int> cats;
      for (int i = cat_boundaries_[cat_idx]; i < cat_boundaries_[cat_idx + 1]; ++i) {
        for (int j = 0; j < 32; ++j) {
          int cat = (i - cat_boundaries_[cat_idx]) * 32 + j;
          if (Common::FindInBitset(cat_threshold_.data() + cat_boundaries_[cat_idx],
                                   cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx], cat)) {
            cats.push_back(cat);
          }
        }
      }
306
307
      str_buf << "\"threshold\":\"" << Common::Join(cats, "||") << "\"," << '\n';
      str_buf << "\"decision_type\":\"==\"," << '\n';
308
    } else {
309
310
      str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << '\n';
      str_buf << "\"decision_type\":\"<=\"," << '\n';
311
312
    }
    if (GetDecisionType(decision_type_[index], kDefaultLeftMask)) {
313
      str_buf << "\"default_left\":true," << '\n';
314
    } else {
315
      str_buf << "\"default_left\":false," << '\n';
316
317
318
    }
    uint8_t missing_type = GetMissingType(decision_type_[index]);
    if (missing_type == 0) {
319
      str_buf << "\"missing_type\":\"None\"," << '\n';
320
    } else if (missing_type == 1) {
321
      str_buf << "\"missing_type\":\"Zero\"," << '\n';
322
    } else {
323
      str_buf << "\"missing_type\":\"NaN\"," << '\n';
324
    }
325
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << '\n';
326
    str_buf << "\"internal_weight\":" << internal_weight_[index] << "," << '\n';
327
328
329
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << '\n';
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << '\n';
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << '\n';
330
    str_buf << "}";
wxchan's avatar
wxchan committed
331
332
333
  } else {
    // leaf
    index = ~index;
334
335
336
    str_buf << "{" << '\n';
    str_buf << "\"leaf_index\":" << index << "," << '\n';
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << '\n';
337
    str_buf << "\"leaf_weight\":" << leaf_weight_[index] << "," << '\n';
338
    str_buf << "\"leaf_count\":" << leaf_count_[index] << '\n';
339
    str_buf << "}";
wxchan's avatar
wxchan committed
340
341
  }

342
  return str_buf.str();
wxchan's avatar
wxchan committed
343
344
}

Guolin Ke's avatar
Guolin Ke committed
345
346
347
348
std::string Tree::NumericalDecisionIfElse(int node) const {
  std::stringstream str_buf;
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
Guolin Ke's avatar
Guolin Ke committed
349
  if (missing_type == 0 || (missing_type == 1 && default_left && kZeroThreshold < threshold_[node])) {
Guolin Ke's avatar
Guolin Ke committed
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
    str_buf << "if (fval <= " << threshold_[node] << ") {";
  } else if (missing_type == 1) {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !Tree::IsZero(fval)" << " && !std::isnan(fval)) {";
    }
  } else {
    if (default_left) {
      str_buf << "if (fval <= " << threshold_[node] << " || std::isnan(fval)) {";
    } else {
      str_buf << "if (fval <= " << threshold_[node] << " && !std::isnan(fval)) {";
    }
  }
  return str_buf.str();
}

std::string Tree::CategoricalDecisionIfElse(int node) const {
  uint8_t missing_type = GetMissingType(decision_type_[node]);
  std::stringstream str_buf;
  if (missing_type == 2) {
    str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
  } else {
    str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
  }
375
  int cat_idx = static_cast<int>(threshold_[node]);
376
377
378
379
  str_buf << "if (int_fval >= 0 && int_fval < 32 * (";
  str_buf << cat_boundaries_[cat_idx + 1] - cat_boundaries_[cat_idx];
  str_buf << ") && (((cat_threshold[" << cat_boundaries_[cat_idx];
  str_buf << " + int_fval / 32] >> (int_fval & 31)) & 1))) {";
Guolin Ke's avatar
Guolin Ke committed
380
381
382
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
383
std::string Tree::ToIfElse(int index, bool predict_leaf_index) const {
384
385
  std::stringstream str_buf;
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
386
  if (predict_leaf_index) {
387
388
389
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
390
391
  if (num_leaves_ <= 1) {
    str_buf << "return " << leaf_value_[0] << ";";
392
  } else {
393
394
395
396
397
398
399
400
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
401
402
403
404
405
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
406
    str_buf << NodeToIfElse(0, predict_leaf_index);
407
  }
408
  str_buf << " }" << '\n';
409

410
  // Predict func by Map to ifelse
411
  str_buf << "double PredictTree" << index;
Guolin Ke's avatar
Guolin Ke committed
412
  if (predict_leaf_index) {
Guolin Ke's avatar
Guolin Ke committed
413
414
415
    str_buf << "LeafByMap";
  } else {
    str_buf << "ByMap";
416
417
418
  }
  str_buf << "(const std::unordered_map<int, double>& arr) { ";
  if (num_leaves_ <= 1) {
Guolin Ke's avatar
Guolin Ke committed
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
    str_buf << "return " << leaf_value_[0] << ";";
  } else {
    str_buf << "const std::vector<uint32_t> cat_threshold = {";
    for (size_t i = 0; i < cat_threshold_.size(); ++i) {
      if (i != 0) {
        str_buf << ",";
      }
      str_buf << cat_threshold_[i];
    }
    str_buf << "};";
    // use this for the missing value conversion
    str_buf << "double fval = 0.0f; ";
    if (num_cat_ > 0) {
      str_buf << "int int_fval = 0; ";
    }
Guolin Ke's avatar
Guolin Ke committed
434
    str_buf << NodeToIfElseByMap(0, predict_leaf_index);
435
  }
436
  str_buf << " }" << '\n';
437

438
439
440
  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
441
std::string Tree::NodeToIfElse(int index, bool predict_leaf_index) const {
442
443
444
445
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
446
    str_buf << "fval = arr[" << split_feature_[index] << "];";
Guolin Ke's avatar
Guolin Ke committed
447
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
448
      str_buf << NumericalDecisionIfElse(index);
449
    } else {
450
      str_buf << CategoricalDecisionIfElse(index);
451
452
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
453
    str_buf << NodeToIfElse(left_child_[index], predict_leaf_index);
454
455
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
456
    str_buf << NodeToIfElse(right_child_[index], predict_leaf_index);
457
458
459
460
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
461
    if (predict_leaf_index) {
462
463
464
465
466
467
468
469
470
471
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
472
std::string Tree::NodeToIfElseByMap(int index, bool predict_leaf_index) const {
473
474
475
476
477
478
479
480
481
482
483
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
    str_buf << "fval = arr.count(" << split_feature_[index] << ") > 0 ? arr.at(" << split_feature_[index] << ") : 0.0f;";
    if (GetDecisionType(decision_type_[index], kCategoricalMask) == 0) {
      str_buf << NumericalDecisionIfElse(index);
    } else {
      str_buf << CategoricalDecisionIfElse(index);
    }
    // left subtree
Guolin Ke's avatar
Guolin Ke committed
484
    str_buf << NodeToIfElseByMap(left_child_[index], predict_leaf_index);
485
486
    str_buf << " } else { ";
    // right subtree
Guolin Ke's avatar
Guolin Ke committed
487
    str_buf << NodeToIfElseByMap(right_child_[index], predict_leaf_index);
488
489
490
491
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
Guolin Ke's avatar
Guolin Ke committed
492
    if (predict_leaf_index) {
493
494
495
496
497
498
499
500
501
502
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

503
504
// Reconstructs a tree from its "key=value" text form.
//
// Reads at most 17 "key=value" lines starting at `str`, stopping early at an
// empty line or at the end of the string, and stores in *used_len the number
// of characters consumed.
//
// Required keys: num_leaves, num_cat, leaf_value; for trees with more than
// one leaf also left_child, right_child, split_feature, threshold; and, when
// num_cat > 0, cat_boundaries and cat_threshold. A missing required key is
// reported through Log::Fatal. Optional arrays that are absent are sized to
// match and left value-initialized; shrinkage defaults to 1.
//
// Compared with a scan that only looks for '=' and line breaks, this parser
// also stops at the terminating '\0' (assumes `str` is NUL-terminated —
// confirm against callers), so a truncated model or a final line without a
// trailing newline is not read past its end.
Tree::Tree(const char* str, size_t* used_len) {
  auto p = str;
  std::unordered_map<std::string, std::string> key_vals;
  const int max_num_line = 17;
  int read_line = 0;
  while (read_line < max_num_line) {
    // an empty line or the end of the input terminates the tree section
    if (*p == '\0' || *p == '\r' || *p == '\n') break;
    auto start = p;
    while (*p != '=' && *p != '\0') ++p;
    if (*p == '\0') {
      // input ended in the middle of a key: there is no value to read
      Log::Fatal("Tree model string format error, expected key=value");
    }
    std::string key(start, p - start);
    ++p;
    start = p;
    while (*p != '\0' && *p != '\r' && *p != '\n') ++p;
    key_vals[key] = std::string(start, p - start);
    ++read_line;
    // accept "\n" as well as "\r\n" line endings
    if (*p == '\r') ++p;
    if (*p == '\n') ++p;
  }
  *used_len = p - str;

  // set before any early return so a single-leaf tree never carries an
  // indeterminate depth
  max_depth_ = -1;

  if (key_vals.count("num_leaves") <= 0) {
    Log::Fatal("Tree model should contain num_leaves field");
  }

  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);

  if (key_vals.count("num_cat") <= 0) {
    Log::Fatal("Tree model should contain num_cat field");
  }

  Common::Atoi(key_vals["num_cat"].c_str(), &num_cat_);

  if (key_vals.count("leaf_value")) {
    leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], num_leaves_);
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

  if (key_vals.count("shrinkage")) {
    Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }

  // a single-leaf tree has no internal nodes, so nothing else is parsed
  if (num_leaves_ <= 1) { return; }

  // required description of the internal nodes
  if (key_vals.count("left_child")) {
    left_child_ = Common::StringToArrayFast<int>(key_vals["left_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
  }

  if (key_vals.count("right_child")) {
    right_child_ = Common::StringToArrayFast<int>(key_vals["right_child"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (key_vals.count("split_feature")) {
    split_feature_ = Common::StringToArrayFast<int>(key_vals["split_feature"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (key_vals.count("threshold")) {
    threshold_ = Common::StringToArray<double>(key_vals["threshold"], num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  // optional statistics: sized to match when the key is absent
  if (key_vals.count("split_gain")) {
    split_gain_ = Common::StringToArrayFast<float>(key_vals["split_gain"], num_leaves_ - 1);
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_count")) {
    internal_count_ = Common::StringToArrayFast<int>(key_vals["internal_count"], num_leaves_ - 1);
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_value")) {
    internal_value_ = Common::StringToArrayFast<double>(key_vals["internal_value"], num_leaves_ - 1);
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_weight")) {
    internal_weight_ = Common::StringToArrayFast<double>(key_vals["internal_weight"], num_leaves_ - 1);
  } else {
    internal_weight_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("leaf_weight")) {
    leaf_weight_ = Common::StringToArrayFast<double>(key_vals["leaf_weight"], num_leaves_);
  } else {
    leaf_weight_.resize(num_leaves_);
  }

  if (key_vals.count("leaf_count")) {
    leaf_count_ = Common::StringToArrayFast<int>(key_vals["leaf_count"], num_leaves_);
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (key_vals.count("decision_type")) {
    decision_type_ = Common::StringToArrayFast<int8_t>(key_vals["decision_type"], num_leaves_ - 1);
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }

  // categorical splits need both the boundaries and the bitset words; the
  // number of words expected is the last boundary
  if (num_cat_ > 0) {
    if (key_vals.count("cat_boundaries")) {
      cat_boundaries_ = Common::StringToArrayFast<int>(key_vals["cat_boundaries"], num_cat_ + 1);
    } else {
      Log::Fatal("Tree model should contain cat_boundaries field.");
    }

    if (key_vals.count("cat_threshold")) {
      cat_threshold_ = Common::StringToArrayFast<uint32_t>(key_vals["cat_threshold"], cat_boundaries_.back());
    } else {
      Log::Fatal("Tree model should contain cat_threshold field");
    }
  }
}

Guolin Ke's avatar
Guolin Ke committed
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
// Appends a new element at position unique_depth of unique_path and updates
// the pweight of every element in [0, unique_depth].
// The new element's pweight starts at 1 for an empty path and 0 otherwise;
// the earlier elements are then visited from the deepest to the first, each
// passing a one_fraction-scaled share of its weight to its successor and
// keeping a zero_fraction-scaled share.
void Tree::ExtendPath(PathElement *unique_path, int unique_depth,
                      double zero_fraction, double one_fraction, int feature_index) {
  PathElement& appended = unique_path[unique_depth];
  appended.feature_index = feature_index;
  appended.zero_fraction = zero_fraction;
  appended.one_fraction = one_fraction;
  appended.pweight = (unique_depth == 0 ? 1 : 0);

  const double path_length = static_cast<double>(unique_depth + 1);
  for (int pos = unique_depth - 1; pos >= 0; pos--) {
    PathElement& current = unique_path[pos];
    PathElement& next = unique_path[pos + 1];
    next.pweight += one_fraction*current.pweight*(pos + 1)
      / path_length;
    current.pweight = zero_fraction*current.pweight*(unique_depth - pos)
      / path_length;
  }
}

// Removes the element at path_index from unique_path (whose last element is
// at position unique_depth).
// First recomputes pweight for positions [0, unique_depth) using the removed
// element's one_fraction / zero_fraction, then shifts feature_index,
// zero_fraction and one_fraction of the elements after path_index down by
// one position. pweight values are not shifted.
// NOTE(review): this looks like the inverse of the pweight update done in
// ExtendPath — confirm against the TreeSHAP reference.
void Tree::UnwindPath(PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;

  // walk from the deepest position to the first; the order matters because
  // each step consumes the value carried over from the previous one
  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      const double tmp = unique_path[i].pweight;
      unique_path[i].pweight = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      next_one_portion = tmp - unique_path[i].pweight*zero_fraction*(unique_depth - i)
        / static_cast<double>(unique_depth + 1);
    } else {
      // one_fraction == 0: only the zero_fraction scaling has to be undone
      unique_path[i].pweight = (unique_path[i].pweight*(unique_depth + 1))
        / static_cast<double>(zero_fraction*(unique_depth - i));
    }
  }

  // close the gap left by the removed element
  for (int i = path_index; i < unique_depth; ++i) {
    unique_path[i].feature_index = unique_path[i + 1].feature_index;
    unique_path[i].zero_fraction = unique_path[i + 1].zero_fraction;
    unique_path[i].one_fraction = unique_path[i + 1].one_fraction;
  }
}

/*!
 * \brief Sum of the path weights that would remain if the element at path_index
 *        were removed, computed without modifying the path.
 *
 * \param unique_path Path buffer; slots [0, unique_depth] are read
 * \param unique_depth Index of the last element currently on the path
 * \param path_index Index of the element to hypothetically remove
 * \return Total of the recomputed weights over slots [0, unique_depth).
 *         Returns 0 when the element's one_fraction and zero_fraction are both 0.
 */
double Tree::UnwoundPathSum(const PathElement *unique_path, int unique_depth, int path_index) {
  const double one_fraction = unique_path[path_index].one_fraction;
  const double zero_fraction = unique_path[path_index].zero_fraction;
  double next_one_portion = unique_path[unique_depth].pweight;
  double total = 0;
  for (int i = unique_depth - 1; i >= 0; --i) {
    if (one_fraction != 0) {
      const double tmp = next_one_portion*(unique_depth + 1)
        / static_cast<double>((i + 1)*one_fraction);
      total += tmp;
      next_one_portion = unique_path[i].pweight - tmp*zero_fraction*((unique_depth - i)
                                                                     / static_cast<double>(unique_depth + 1));
    } else if (zero_fraction != 0) {
      total += (unique_path[i].pweight / zero_fraction) / ((unique_depth - i)
                                                           / static_cast<double>(unique_depth + 1));
    }
    // Both fractions are zero: dividing by zero_fraction would produce NaN/inf
    // that survives the caller's multiplication by (one_fraction - zero_fraction) == 0.
    // Such an element contributes nothing, so the term is skipped.
  }
  return total;
}

// recursive computation of SHAP values for a decision tree
void Tree::TreeSHAP(const double *feature_values, double *phi,
                    int node, int unique_depth,
                    PathElement *parent_unique_path, double parent_zero_fraction,
                    double parent_one_fraction, int parent_feature_index) const {
  // extend the unique path
696
  PathElement* unique_path = parent_unique_path + unique_depth;
Guolin Ke's avatar
Guolin Ke committed
697
698
699
700
701
702
703
704
705
706
707
708
709
710
  if (unique_depth > 0) std::copy(parent_unique_path, parent_unique_path + unique_depth, unique_path);
  ExtendPath(unique_path, unique_depth, parent_zero_fraction,
             parent_one_fraction, parent_feature_index);

  // leaf node
  if (node < 0) {
    for (int i = 1; i <= unique_depth; ++i) {
      const double w = UnwoundPathSum(unique_path, unique_depth, i);
      const PathElement &el = unique_path[i];
      phi[el.feature_index] += w*(el.one_fraction - el.zero_fraction)*leaf_value_[~node];
    }

    // internal node
  } else {
711
    const int hot_index = Decision(feature_values[split_feature_[node]], node);
Guolin Ke's avatar
Guolin Ke committed
712
713
714
715
716
717
718
719
720
721
722
    const int cold_index = (hot_index == left_child_[node] ? right_child_[node] : left_child_[node]);
    const double w = data_count(node);
    const double hot_zero_fraction = data_count(hot_index) / w;
    const double cold_zero_fraction = data_count(cold_index) / w;
    double incoming_zero_fraction = 1;
    double incoming_one_fraction = 1;

    // see if we have already split on this feature,
    // if so we undo that split so we can redo it for this node
    int path_index = 0;
    for (; path_index <= unique_depth; ++path_index) {
723
      if (unique_path[path_index].feature_index == split_feature_[node]) break;
Guolin Ke's avatar
Guolin Ke committed
724
725
726
727
728
729
730
731
732
    }
    if (path_index != unique_depth + 1) {
      incoming_zero_fraction = unique_path[path_index].zero_fraction;
      incoming_one_fraction = unique_path[path_index].one_fraction;
      UnwindPath(unique_path, unique_depth, path_index);
      unique_depth -= 1;
    }

    TreeSHAP(feature_values, phi, hot_index, unique_depth + 1, unique_path,
733
             hot_zero_fraction*incoming_zero_fraction, incoming_one_fraction, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
734
735

    TreeSHAP(feature_values, phi, cold_index, unique_depth + 1, unique_path,
736
             cold_zero_fraction*incoming_zero_fraction, 0, split_feature_[node]);
Guolin Ke's avatar
Guolin Ke committed
737
738
739
  }
}

740
741
742
743
744
double Tree::ExpectedValue() const {
  if (num_leaves_ == 1) return LeafOutput(0);
  const double total_count = internal_count_[0];
  double exp_value = 0.0;
  for (int i = 0; i < num_leaves(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
745
    exp_value += (leaf_count_[i] / total_count)*LeafOutput(i);
Guolin Ke's avatar
Guolin Ke committed
746
  }
747
  return exp_value;
Guolin Ke's avatar
Guolin Ke committed
748
749
}

750
751
752
753
754
755
756
757
758
759
760
/*!
 * \brief Recompute max_depth_ as the largest entry of leaf_depth_ over the current leaves.
 *
 * A single-leaf tree gets depth 0. If leaf_depth_ is empty it is first rebuilt
 * through RecomputeLeafDepths(0, 0).
 */
void Tree::RecomputeMaxDepth() {
  if (num_leaves_ == 1) {
    max_depth_ = 0;
    return;
  }
  if (leaf_depth_.empty()) {
    RecomputeLeafDepths(0, 0);
  }
  max_depth_ = leaf_depth_[0];
  for (int i = 1; i < num_leaves(); ++i) {
    if (leaf_depth_[i] > max_depth_) max_depth_ = leaf_depth_[i];
  }
}

Guolin Ke's avatar
Guolin Ke committed
764
}  // namespace LightGBM