"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "64e52093b37d710836b74ce4291013a46d5a0dec"
tree.cpp 20.3 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <LightGBM/tree.h>

#include <LightGBM/utils/threading.h>
#include <LightGBM/utils/common.h>

#include <LightGBM/dataset.h>

#include <sstream>
#include <unordered_map>
#include <functional>
#include <vector>
#include <string>
Guolin Ke's avatar
Guolin Ke committed
13
#include <memory>
14
#include <iomanip>
Guolin Ke's avatar
Guolin Ke committed
15
16
17

namespace LightGBM {

18
19
20
21
22
std::vector<bool(*)(uint32_t, uint32_t)> Tree::inner_decision_funs =
{ Tree::NumericalDecision<uint32_t>, Tree::CategoricalDecision<uint32_t> };
std::vector<bool(*)(double, double)> Tree::decision_funs =
{ Tree::NumericalDecision<double>, Tree::CategoricalDecision<double> };

Guolin Ke's avatar
Guolin Ke committed
23
24
25
Tree::Tree(int max_leaves)
  :max_leaves_(max_leaves) {
  // A binary tree with max_leaves_ leaves has max_leaves_ - 1 internal (split) nodes.
  const int num_internal = max_leaves_ - 1;

  // Storage indexed by internal-node id.
  left_child_.resize(num_internal);
  right_child_.resize(num_internal);
  split_feature_inner_.resize(num_internal);
  split_feature_.resize(num_internal);
  threshold_in_bin_.resize(num_internal);
  threshold_.resize(num_internal);
  decision_type_.resize(num_internal);
  default_value_.resize(num_internal);
  zero_bin_.resize(num_internal);
  default_bin_for_zero_.resize(num_internal);
  split_gain_.resize(num_internal);
  internal_value_.resize(num_internal);
  internal_count_.resize(num_internal);

  // Storage indexed by leaf id.
  leaf_parent_.resize(max_leaves_);
  leaf_value_.resize(max_leaves_);
  leaf_count_.resize(max_leaves_);
  leaf_depth_.resize(max_leaves_);

  // The tree starts as a lone root leaf: depth 0, no parent (-1).
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_parent_[0] = -1;

  shrinkage_ = 1.0f;
  has_categorical_ = false;
}
// All owned state lives in standard containers, so the compiler-generated
// destructor body is sufficient.
Tree::~Tree() = default;

Guolin Ke's avatar
Guolin Ke committed
55
56
57
int Tree::Split(int leaf, int feature, BinType bin_type, uint32_t threshold_bin, int real_feature, double threshold_double, 
                double left_value, double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain,
                uint32_t zero_bin, uint32_t default_bin_for_zero, double default_value) {
Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
62
63
64
65
66
67
68
69
  int new_node_idx = num_leaves_ - 1;
  // update parent info
  int parent = leaf_parent_[leaf];
  if (parent >= 0) {
    // if cur node is left child
    if (left_child_[parent] == ~leaf) {
      left_child_[parent] = new_node_idx;
    } else {
      right_child_[parent] = new_node_idx;
    }
  }
  // add new node
Guolin Ke's avatar
Guolin Ke committed
70
  split_feature_inner_[new_node_idx] = feature;
Guolin Ke's avatar
Guolin Ke committed
71
  split_feature_[new_node_idx] = real_feature;
Guolin Ke's avatar
Guolin Ke committed
72
73
74
75
76

  zero_bin_[new_node_idx] = zero_bin;
  default_bin_for_zero_[new_node_idx] = default_bin_for_zero;
  default_value_[new_node_idx] = Common::AvoidInf(default_value);

77
78
79
80
81
82
  if (bin_type == BinType::NumericalBin) {
    decision_type_[new_node_idx] = 0;
  } else {
    has_categorical_ = true;
    decision_type_[new_node_idx] = 1;
  }
Guolin Ke's avatar
Guolin Ke committed
83

Guolin Ke's avatar
Guolin Ke committed
84
  threshold_in_bin_[new_node_idx] = threshold_bin;
85
  threshold_[new_node_idx] = Common::AvoidInf(threshold_double);
Guolin Ke's avatar
Guolin Ke committed
86
  split_gain_[new_node_idx] = Common::AvoidInf(gain);
Guolin Ke's avatar
Guolin Ke committed
87
88
89
90
91
92
  // add two new leaves
  left_child_[new_node_idx] = ~leaf;
  right_child_[new_node_idx] = ~num_leaves_;
  // update new leaves
  leaf_parent_[leaf] = new_node_idx;
  leaf_parent_[num_leaves_] = new_node_idx;
93
94
  // save current leaf value to internal node before change
  internal_value_[new_node_idx] = leaf_value_[leaf];
Guolin Ke's avatar
Guolin Ke committed
95
  internal_count_[new_node_idx] = left_cnt + right_cnt;
Guolin Ke's avatar
Guolin Ke committed
96
  leaf_value_[leaf] = std::isnan(left_value) ? 0.0f : left_value;
Guolin Ke's avatar
Guolin Ke committed
97
  leaf_count_[leaf] = left_cnt;
Guolin Ke's avatar
Guolin Ke committed
98
  leaf_value_[num_leaves_] = std::isnan(right_value) ? 0.0f : right_value;
Guolin Ke's avatar
Guolin Ke committed
99
  leaf_count_[num_leaves_] = right_cnt;
Guolin Ke's avatar
Guolin Ke committed
100
101
102
  // update leaf depth
  leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
  leaf_depth_[leaf]++;
Guolin Ke's avatar
Guolin Ke committed
103
104
105
106
107

  ++num_leaves_;
  return num_leaves_ - 1;
}

108
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
109
  if (num_leaves_ <= 1) { return; }
110
111
112
113
114
115
  if (has_categorical_) {
    if (data->num_features() > num_leaves_ - 1) {
      Threading::For<data_size_t>(0, num_data,
        [this, &data, score](int, data_size_t start, data_size_t end) {
        std::vector<std::unique_ptr<BinIterator>> iter(num_leaves_ - 1);
        for (int i = 0; i < num_leaves_ - 1; ++i) {
Guolin Ke's avatar
Guolin Ke committed
116
          const int fidx = split_feature_inner_[i];
117
118
119
120
121
122
          iter[i].reset(data->FeatureIterator(fidx));
          iter[i]->Reset(start);
        }
        for (data_size_t i = start; i < end; ++i) {
          int node = 0;
          while (node >= 0) {
Guolin Ke's avatar
Guolin Ke committed
123
            uint32_t fval = DefaultValueForZero(iter[node]->Get(i), zero_bin_[node], default_bin_for_zero_[node]);
124
            if (inner_decision_funs[decision_type_[node]](
Guolin Ke's avatar
Guolin Ke committed
125
              fval,
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
              threshold_in_bin_[node])) {
              node = left_child_[node];
            } else {
              node = right_child_[node];
            }
          }
          score[i] += static_cast<double>(leaf_value_[~node]);
        }
      });
    } else {
      Threading::For<data_size_t>(0, num_data,
        [this, &data, score](int, data_size_t start, data_size_t end) {
        std::vector<std::unique_ptr<BinIterator>> iter(data->num_features());
        for (int i = 0; i < data->num_features(); ++i) {
          iter[i].reset(data->FeatureIterator(i));
          iter[i]->Reset(start);
        }
        for (data_size_t i = start; i < end; ++i) {
          int node = 0;
          while (node >= 0) {
Guolin Ke's avatar
Guolin Ke committed
146
            uint32_t fval = DefaultValueForZero(iter[split_feature_inner_[node]]->Get(i), zero_bin_[node], default_bin_for_zero_[node]);
147
            if (inner_decision_funs[decision_type_[node]](
Guolin Ke's avatar
Guolin Ke committed
148
              fval,
149
150
151
152
153
154
155
156
157
158
              threshold_in_bin_[node])) {
              node = left_child_[node];
            } else {
              node = right_child_[node];
            }
          }
          score[i] += static_cast<double>(leaf_value_[~node]);
        }
      });
    }
Guolin Ke's avatar
Guolin Ke committed
159
  } else {
160
161
162
163
164
    if (data->num_features() > num_leaves_ - 1) {
      Threading::For<data_size_t>(0, num_data,
        [this, &data, score](int, data_size_t start, data_size_t end) {
        std::vector<std::unique_ptr<BinIterator>> iter(num_leaves_ - 1);
        for (int i = 0; i < num_leaves_ - 1; ++i) {
Guolin Ke's avatar
Guolin Ke committed
165
          const int fidx = split_feature_inner_[i];
166
167
168
169
170
171
          iter[i].reset(data->FeatureIterator(fidx));
          iter[i]->Reset(start);
        }
        for (data_size_t i = start; i < end; ++i) {
          int node = 0;
          while (node >= 0) {
Guolin Ke's avatar
Guolin Ke committed
172
173
            uint32_t fval = DefaultValueForZero(iter[node]->Get(i), zero_bin_[node], default_bin_for_zero_[node]);
            if (fval <= threshold_in_bin_[node]) {
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
              node = left_child_[node];
            } else {
              node = right_child_[node];
            }
          }
          score[i] += static_cast<double>(leaf_value_[~node]);
        }
      });
    } else {
      Threading::For<data_size_t>(0, num_data,
        [this, &data, score](int, data_size_t start, data_size_t end) {
        std::vector<std::unique_ptr<BinIterator>> iter(data->num_features());
        for (int i = 0; i < data->num_features(); ++i) {
          iter[i].reset(data->FeatureIterator(i));
          iter[i]->Reset(start);
        }
        for (data_size_t i = start; i < end; ++i) {
          int node = 0;
          while (node >= 0) {
Guolin Ke's avatar
Guolin Ke committed
193
194
            uint32_t fval = DefaultValueForZero(iter[split_feature_inner_[node]]->Get(i), zero_bin_[node], default_bin_for_zero_[node]);
            if (fval <= threshold_in_bin_[node]) {
195
196
197
198
199
200
201
202
203
              node = left_child_[node];
            } else {
              node = right_child_[node];
            }
          }
          score[i] += static_cast<double>(leaf_value_[~node]);
        }
      });
    }
Guolin Ke's avatar
Guolin Ke committed
204
  }
Guolin Ke's avatar
Guolin Ke committed
205
206
}

Guolin Ke's avatar
Guolin Ke committed
207
208
209
void Tree::AddPredictionToScore(const Dataset* data,
  const data_size_t* used_data_indices,
  data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
210
  if (num_leaves_ <= 1) { return; }
211
212
213
214
215
216
  if (has_categorical_) {
    if (data->num_features() > num_leaves_ - 1) {
      Threading::For<data_size_t>(0, num_data,
        [this, data, used_data_indices, score](int, data_size_t start, data_size_t end) {
        std::vector<std::unique_ptr<BinIterator>> iter(num_leaves_ - 1);
        for (int i = 0; i < num_leaves_ - 1; ++i) {
Guolin Ke's avatar
Guolin Ke committed
217
          const int fidx = split_feature_inner_[i];
218
219
220
221
222
223
224
          iter[i].reset(data->FeatureIterator(fidx));
          iter[i]->Reset(used_data_indices[start]);
        }
        for (data_size_t i = start; i < end; ++i) {
          int node = 0;
          const data_size_t idx = used_data_indices[i];
          while (node >= 0) {
Guolin Ke's avatar
Guolin Ke committed
225
            uint32_t fval = DefaultValueForZero(iter[node]->Get(idx), zero_bin_[node], default_bin_for_zero_[node]);
226
            if (inner_decision_funs[decision_type_[node]](
Guolin Ke's avatar
Guolin Ke committed
227
              fval,
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
              threshold_in_bin_[node])) {
              node = left_child_[node];
            } else {
              node = right_child_[node];
            }
          }
          score[idx] += static_cast<double>(leaf_value_[~node]);
        }
      });
    } else {
      Threading::For<data_size_t>(0, num_data,
        [this, data, used_data_indices, score](int, data_size_t start, data_size_t end) {
        std::vector<std::unique_ptr<BinIterator>> iter(data->num_features());
        for (int i = 0; i < data->num_features(); ++i) {
          iter[i].reset(data->FeatureIterator(i));
          iter[i]->Reset(used_data_indices[start]);
        }
        for (data_size_t i = start; i < end; ++i) {
          const data_size_t idx = used_data_indices[i];
          int node = 0;
          while (node >= 0) {
Guolin Ke's avatar
Guolin Ke committed
249
            uint32_t fval = DefaultValueForZero(iter[split_feature_inner_[node]]->Get(idx), zero_bin_[node], default_bin_for_zero_[node]);
250
            if (inner_decision_funs[decision_type_[node]](
Guolin Ke's avatar
Guolin Ke committed
251
              fval,
252
253
254
255
256
257
258
259
260
261
              threshold_in_bin_[node])) {
              node = left_child_[node];
            } else {
              node = right_child_[node];
            }
          }
          score[idx] += static_cast<double>(leaf_value_[~node]);
        }
      });
    }
Guolin Ke's avatar
Guolin Ke committed
262
  } else {
263
264
265
266
267
    if (data->num_features() > num_leaves_ - 1) {
      Threading::For<data_size_t>(0, num_data,
        [this, data, used_data_indices, score](int, data_size_t start, data_size_t end) {
        std::vector<std::unique_ptr<BinIterator>> iter(num_leaves_ - 1);
        for (int i = 0; i < num_leaves_ - 1; ++i) {
Guolin Ke's avatar
Guolin Ke committed
268
          const int fidx = split_feature_inner_[i];
269
270
271
272
273
274
275
          iter[i].reset(data->FeatureIterator(fidx));
          iter[i]->Reset(used_data_indices[start]);
        }
        for (data_size_t i = start; i < end; ++i) {
          int node = 0;
          const data_size_t idx = used_data_indices[i];
          while (node >= 0) {
Guolin Ke's avatar
Guolin Ke committed
276
277
            uint32_t fval = DefaultValueForZero(iter[node]->Get(idx), zero_bin_[node], default_bin_for_zero_[node]);
            if (fval <= threshold_in_bin_[node]) {
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
              node = left_child_[node];
            } else {
              node = right_child_[node];
            }
          }
          score[idx] += static_cast<double>(leaf_value_[~node]);
        }
      });
    } else {
      Threading::For<data_size_t>(0, num_data,
        [this, data, used_data_indices, score](int, data_size_t start, data_size_t end) {
        std::vector<std::unique_ptr<BinIterator>> iter(data->num_features());
        for (int i = 0; i < data->num_features(); ++i) {
          iter[i].reset(data->FeatureIterator(i));
          iter[i]->Reset(used_data_indices[start]);
        }
        for (data_size_t i = start; i < end; ++i) {
          const data_size_t idx = used_data_indices[i];
          int node = 0;
          while (node >= 0) {
Guolin Ke's avatar
Guolin Ke committed
298
299
            uint32_t fval = DefaultValueForZero(iter[split_feature_inner_[node]]->Get(idx), zero_bin_[node], default_bin_for_zero_[node]);
            if (fval <= threshold_in_bin_[node]) {
300
301
302
303
304
305
306
307
308
              node = left_child_[node];
            } else {
              node = right_child_[node];
            }
          }
          score[idx] += static_cast<double>(leaf_value_[~node]);
        }
      });
    }
Guolin Ke's avatar
Guolin Ke committed
309
  }
Guolin Ke's avatar
Guolin Ke committed
310
311
312
}

std::string Tree::ToString() {
313
314
315
  std::stringstream str_buf;
  str_buf << "num_leaves=" << num_leaves_ << std::endl;
  str_buf << "split_feature="
Guolin Ke's avatar
Guolin Ke committed
316
    << Common::ArrayToString<int>(split_feature_, num_leaves_ - 1, ' ') << std::endl;
317
  str_buf << "split_gain="
Guolin Ke's avatar
Guolin Ke committed
318
    << Common::ArrayToString<double>(split_gain_, num_leaves_ - 1, ' ') << std::endl;
319
  str_buf << "threshold="
Guolin Ke's avatar
Guolin Ke committed
320
    << Common::ArrayToString<double>(threshold_, num_leaves_ - 1, ' ') << std::endl;
321
322
  str_buf << "decision_type="
    << Common::ArrayToString<int>(Common::ArrayCast<int8_t, int>(decision_type_), num_leaves_ - 1, ' ') << std::endl;
Guolin Ke's avatar
Guolin Ke committed
323
324
  str_buf << "default_value="
    << Common::ArrayToString<double>(default_value_, num_leaves_ - 1, ' ') << std::endl;
325
  str_buf << "left_child="
Guolin Ke's avatar
Guolin Ke committed
326
    << Common::ArrayToString<int>(left_child_, num_leaves_ - 1, ' ') << std::endl;
327
  str_buf << "right_child="
Guolin Ke's avatar
Guolin Ke committed
328
    << Common::ArrayToString<int>(right_child_, num_leaves_ - 1, ' ') << std::endl;
329
  str_buf << "leaf_parent="
Guolin Ke's avatar
Guolin Ke committed
330
    << Common::ArrayToString<int>(leaf_parent_, num_leaves_, ' ') << std::endl;
331
  str_buf << "leaf_value="
Guolin Ke's avatar
Guolin Ke committed
332
    << Common::ArrayToString<double>(leaf_value_, num_leaves_, ' ') << std::endl;
333
  str_buf << "leaf_count="
Guolin Ke's avatar
Guolin Ke committed
334
    << Common::ArrayToString<data_size_t>(leaf_count_, num_leaves_, ' ') << std::endl;
335
  str_buf << "internal_value="
Guolin Ke's avatar
Guolin Ke committed
336
    << Common::ArrayToString<double>(internal_value_, num_leaves_ - 1, ' ') << std::endl;
337
  str_buf << "internal_count="
Guolin Ke's avatar
Guolin Ke committed
338
    << Common::ArrayToString<data_size_t>(internal_count_, num_leaves_ - 1, ' ') << std::endl;
Guolin Ke's avatar
Guolin Ke committed
339
  str_buf << "shrinkage=" << shrinkage_ << std::endl;
Guolin Ke's avatar
Guolin Ke committed
340
  str_buf << "has_categorical=" << (has_categorical_ ? 1 : 0) << std::endl;
341
342
  str_buf << std::endl;
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
343
344
}

wxchan's avatar
wxchan committed
345
std::string Tree::ToJSON() {
346
  std::stringstream str_buf;
347
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
348
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
349
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
350
  str_buf << "\"has_categorical\":" << (has_categorical_ ? 1 : 0) << "," << std::endl;
wxchan's avatar
wxchan committed
351
352
353
354
355
  if (num_leaves_ == 1) {
    str_buf << "\"tree_structure\":" << NodeToJSON(-1) << std::endl;
  } else {
    str_buf << "\"tree_structure\":" << NodeToJSON(0) << std::endl;
  }
wxchan's avatar
wxchan committed
356

357
  return str_buf.str();
wxchan's avatar
wxchan committed
358
359
360
}

std::string Tree::NodeToJSON(int index) {
361
  std::stringstream str_buf;
362
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
363
364
  if (index >= 0) {
    // non-leaf
365
366
    str_buf << "{" << std::endl;
    str_buf << "\"split_index\":" << index << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
367
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << std::endl;
368
    str_buf << "\"split_gain\":" << split_gain_[index] << "," << std::endl;
369
    str_buf << "\"threshold\":" << Common::AvoidInf(threshold_[index]) << "," << std::endl;
370
    str_buf << "\"decision_type\":\"" << Tree::GetDecisionTypeName(decision_type_[index]) << "\"," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
371
    str_buf << "\"default_value\":" << default_value_[index] << "," << std::endl;
372
373
374
375
376
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << std::endl;
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << std::endl;
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << std::endl;
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << std::endl;
    str_buf << "}";
wxchan's avatar
wxchan committed
377
378
379
  } else {
    // leaf
    index = ~index;
380
381
382
383
384
385
    str_buf << "{" << std::endl;
    str_buf << "\"leaf_index\":" << index << "," << std::endl;
    str_buf << "\"leaf_parent\":" << leaf_parent_[index] << "," << std::endl;
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << std::endl;
    str_buf << "\"leaf_count\":" << leaf_count_[index] << std::endl;
    str_buf << "}";
wxchan's avatar
wxchan committed
386
387
  }

388
  return str_buf.str();
wxchan's avatar
wxchan committed
389
390
}

391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
// Emits C++ source for a function `double PredictTree<index>[Leaf](const double* arr)`
// that evaluates this tree on a raw feature array. With is_predict_leaf_index the
// function returns the leaf index instead of the leaf value.
std::string Tree::ToIfElse(int index, bool is_predict_leaf_index) {
  std::stringstream str_buf;
  str_buf << "double PredictTree" << index;
  if (is_predict_leaf_index) {
    str_buf << "Leaf";
  }
  str_buf << "(const double* arr) { ";
  if (num_leaves_ <= 1) {
    // No split nodes to walk. The body must be a complete statement: the former
    // "return 0" (no ';') made the generated C++ fail to compile. Using <= 1 also
    // covers a num_leaves=0 model, which Tree(const std::string&) accepts without
    // populating any node arrays, so NodeToIfElse(0) would index empty vectors.
    str_buf << "return 0;";
  } else {
    str_buf << NodeToIfElse(0, is_predict_leaf_index);
  }
  str_buf << " }" << '\n';
  return str_buf.str();
}

std::string Tree::NodeToIfElse(int index, bool is_predict_leaf_index) {
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  if (index >= 0) {
    // non-leaf
Guolin Ke's avatar
Guolin Ke committed
412
413
414
415
416
    std::stringstream tmp_str_buf;
    tmp_str_buf << "arr[" << split_feature_[index] << "]";
    std::string str_fval = tmp_str_buf.str();
    str_buf << "if( ( " << str_fval <<" <= " << kMissingValueRange  << " && "<< str_fval << " > -" << kMissingValueRange <<" ?  "
      << default_value_[index] << " : " << str_fval << " ) ";
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
    if (decision_type_[index] == 0) {
      str_buf << "<";
    } else {
      str_buf << "=";
    }
    str_buf << "= " << threshold_[index] << " ) { ";
    // left subtree
    str_buf << NodeToIfElse(left_child_[index], is_predict_leaf_index);
    str_buf << " } else { ";
    // right subtree
    str_buf << NodeToIfElse(right_child_[index], is_predict_leaf_index);
    str_buf << " }";
  } else {
    // leaf
    str_buf << "return ";
    if (is_predict_leaf_index) {
      str_buf << ~index;
    } else {
      str_buf << leaf_value_[~index];
    }
    str_buf << ";";
  }

  return str_buf.str();
}

Guolin Ke's avatar
Guolin Ke committed
443
// Reconstructs a tree from the "key=value" text produced by ToString().
// num_leaves, left_child, right_child, split_feature, threshold, default_value and
// leaf_value are required (Log::Fatal otherwise). The remaining fields fall back to
// zero-filled arrays or defaults when absent.
// NOTE(review): max_leaves_ and the training-only arrays (leaf_depth_,
// threshold_in_bin_, zero_bin_, default_bin_for_zero_, split_feature_inner_) are not
// populated here, so Split()/AddPredictionToScore() are presumably not meant to be
// used on a deserialized tree -- confirm with callers.
Tree::Tree(const std::string& str) {
  std::vector<std::string> lines = Common::SplitLines(str.c_str());
  // Collect "key=value" pairs. Lines that do not split into exactly two parts,
  // or whose trimmed key or value is empty, are ignored.
  std::unordered_map<std::string, std::string> key_vals;
  for (const std::string& line : lines) {
    std::vector<std::string> tmp_strs = Common::Split(line.c_str(), '=');
    if (tmp_strs.size() == 2) {
      std::string key = Common::Trim(tmp_strs[0]);
      std::string val = Common::Trim(tmp_strs[1]);
      if (key.size() > 0 && val.size() > 0) {
        key_vals[key] = val;
      }
    }
  }
  if (key_vals.count("num_leaves") == 0) {
    Log::Fatal("Tree model should contain num_leaves field.");
  }

  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);

  // Scalar fields are parsed BEFORE the single-leaf early return below. Previously
  // they were parsed after it, so a tree with num_leaves <= 1 left shrinkage_ and
  // has_categorical_ uninitialized, and ToString()/ToJSON() then read indeterminate
  // values.
  if (key_vals.count("shrinkage")) {
    Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
  } else {
    shrinkage_ = 1.0f;
  }

  if (key_vals.count("has_categorical")) {
    int t = 0;
    Common::Atoi(key_vals["has_categorical"].c_str(), &t);
    has_categorical_ = t > 0;
  } else {
    has_categorical_ = false;
  }

  // A single-leaf (or empty) tree has no split nodes to load.
  if (num_leaves_ <= 1) { return; }

  // Required structural arrays.
  if (key_vals.count("left_child")) {
    left_child_ = Common::StringToArray<int>(key_vals["left_child"], ' ', num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain left_child field");
  }

  if (key_vals.count("right_child")) {
    right_child_ = Common::StringToArray<int>(key_vals["right_child"], ' ', num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain right_child field");
  }

  if (key_vals.count("split_feature")) {
    split_feature_ = Common::StringToArray<int>(key_vals["split_feature"], ' ', num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain split_feature field");
  }

  if (key_vals.count("threshold")) {
    threshold_ = Common::StringToArray<double>(key_vals["threshold"], ' ', num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain threshold field");
  }

  if (key_vals.count("default_value")) {
    default_value_ = Common::StringToArray<double>(key_vals["default_value"], ' ', num_leaves_ - 1);
  } else {
    Log::Fatal("Tree model string format error, should contain default_value field");
  }

  if (key_vals.count("leaf_value")) {
    leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], ' ', num_leaves_);
  } else {
    Log::Fatal("Tree model string format error, should contain leaf_value field");
  }

  // Optional arrays: value-initialized (zero-filled) when absent.
  if (key_vals.count("split_gain")) {
    split_gain_ = Common::StringToArray<double>(key_vals["split_gain"], ' ', num_leaves_ - 1);
  } else {
    split_gain_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_count")) {
    internal_count_ = Common::StringToArray<data_size_t>(key_vals["internal_count"], ' ', num_leaves_ - 1);
  } else {
    internal_count_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("internal_value")) {
    internal_value_ = Common::StringToArray<double>(key_vals["internal_value"], ' ', num_leaves_ - 1);
  } else {
    internal_value_.resize(num_leaves_ - 1);
  }

  if (key_vals.count("leaf_count")) {
    leaf_count_ = Common::StringToArray<data_size_t>(key_vals["leaf_count"], ' ', num_leaves_);
  } else {
    leaf_count_.resize(num_leaves_);
  }

  if (key_vals.count("leaf_parent")) {
    leaf_parent_ = Common::StringToArray<int>(key_vals["leaf_parent"], ' ', num_leaves_);
  } else {
    leaf_parent_.resize(num_leaves_);
  }

  // Missing decision types default to 0 (numerical) for every node.
  if (key_vals.count("decision_type")) {
    decision_type_ = Common::StringToArray<int8_t>(key_vals["decision_type"], ' ', num_leaves_ - 1);
  } else {
    decision_type_ = std::vector<int8_t>(num_leaves_ - 1, 0);
  }
}

}  // namespace LightGBM