tree.cpp 10 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
10
11
12
#include <LightGBM/tree.h>

#include <LightGBM/utils/threading.h>
#include <LightGBM/utils/common.h>

#include <LightGBM/dataset.h>

#include <functional>
#include <iomanip>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
15
16
17
18
19
20

namespace LightGBM {

// Creates a tree consisting of a single root leaf, with storage reserved for
// growing up to max_leaves leaves through Split().
//
// max_leaves: upper bound on the number of leaves; every per-leaf array gets
//             max_leaves slots and every per-split-node array max_leaves - 1.
Tree::Tree(int max_leaves)
  :max_leaves_(max_leaves) {
  // Each Split() call adds exactly one split node and one leaf, so a tree
  // that starts with one leaf never needs more than max_leaves_ - 1 nodes.
  const int max_split_nodes = max_leaves_ - 1;

  // storage indexed by split node
  left_child_ = std::vector<int>(max_split_nodes);
  right_child_ = std::vector<int>(max_split_nodes);
  split_feature_inner = std::vector<int>(max_split_nodes);
  split_feature_ = std::vector<int>(max_split_nodes);
  threshold_in_bin_ = std::vector<uint32_t>(max_split_nodes);
  threshold_ = std::vector<double>(max_split_nodes);
  split_gain_ = std::vector<double>(max_split_nodes);
  internal_value_ = std::vector<double>(max_split_nodes);
  internal_count_ = std::vector<data_size_t>(max_split_nodes);

  // storage indexed by leaf
  leaf_parent_ = std::vector<int>(max_leaves_);
  leaf_value_ = std::vector<double>(max_leaves_);
  leaf_count_ = std::vector<data_size_t>(max_leaves_);
  leaf_depth_ = std::vector<int>(max_leaves_);

  // the only leaf so far is the root: depth 0, no parent (-1)
  num_leaves_ = 1;
  leaf_depth_[0] = 0;
  leaf_parent_[0] = -1;

  shrinkage_ = 1.0f;
}
// Nothing to release by hand: the destructor body is empty, so the members
// clean up after themselves.
Tree::~Tree() = default;

Guolin Ke's avatar
Guolin Ke committed
45
int Tree::Split(int leaf, int feature, uint32_t threshold_bin, int real_feature,
Guolin Ke's avatar
Guolin Ke committed
46
47
    double threshold_double, double left_value,
    double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain) {
Guolin Ke's avatar
Guolin Ke committed
48
49
50
51
52
53
54
55
56
57
58
59
  int new_node_idx = num_leaves_ - 1;
  // update parent info
  int parent = leaf_parent_[leaf];
  if (parent >= 0) {
    // if cur node is left child
    if (left_child_[parent] == ~leaf) {
      left_child_[parent] = new_node_idx;
    } else {
      right_child_[parent] = new_node_idx;
    }
  }
  // add new node
Guolin Ke's avatar
Guolin Ke committed
60
61
  split_feature_inner[new_node_idx] = feature;
  split_feature_[new_node_idx] = real_feature;
Guolin Ke's avatar
Guolin Ke committed
62
  threshold_in_bin_[new_node_idx] = threshold_bin;
Guolin Ke's avatar
Guolin Ke committed
63
  threshold_[new_node_idx] = threshold_double;
Guolin Ke's avatar
Guolin Ke committed
64
65
66
67
68
69
70
  split_gain_[new_node_idx] = gain;
  // add two new leaves
  left_child_[new_node_idx] = ~leaf;
  right_child_[new_node_idx] = ~num_leaves_;
  // update new leaves
  leaf_parent_[leaf] = new_node_idx;
  leaf_parent_[num_leaves_] = new_node_idx;
71
72
  // save current leaf value to internal node before change
  internal_value_[new_node_idx] = leaf_value_[leaf];
Guolin Ke's avatar
Guolin Ke committed
73
  internal_count_[new_node_idx] = left_cnt + right_cnt;
Guolin Ke's avatar
Guolin Ke committed
74
  leaf_value_[leaf] = left_value;
Guolin Ke's avatar
Guolin Ke committed
75
  leaf_count_[leaf] = left_cnt;
Guolin Ke's avatar
Guolin Ke committed
76
  leaf_value_[num_leaves_] = right_value;
Guolin Ke's avatar
Guolin Ke committed
77
  leaf_count_[num_leaves_] = right_cnt;
Guolin Ke's avatar
Guolin Ke committed
78
79
80
  // update leaf depth
  leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
  leaf_depth_[leaf]++;
Guolin Ke's avatar
Guolin Ke committed
81
82
83
84
85

  ++num_leaves_;
  return num_leaves_ - 1;
}

86
void Tree::AddPredictionToScore(const Dataset* data, data_size_t num_data, double* score) const {
Guolin Ke's avatar
Guolin Ke committed
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
  if (data->num_features() > num_leaves_ - 1) {
    Threading::For<data_size_t>(0, num_data,
      [this, &data, score](int, data_size_t start, data_size_t end) {
      std::vector<std::unique_ptr<BinIterator>> iter(num_leaves_ - 1);
      for (int i = 0; i < num_leaves_ - 1; ++i) {
        const int fidx = split_feature_inner[i];
        iter[i].reset(data->FeatureIterator(fidx));
        iter[i]->Reset(start);
      }
      for (data_size_t i = start; i < end; ++i) {
        score[i] += static_cast<double>(leaf_value_[GetLeaf(iter, i)]);
      }
    });
  } else {
    Threading::For<data_size_t>(0, num_data,
      [this, &data, score](int, data_size_t start, data_size_t end) {
      std::vector<std::unique_ptr<BinIterator>> iter(data->num_features());
      for (int i = 0; i < data->num_features(); ++i) {
        iter[i].reset(data->FeatureIterator(i));
        iter[i]->Reset(start);
      }
      for (data_size_t i = start; i < end; ++i) {
        score[i] += static_cast<double>(leaf_value_[GetLeafRaw(iter, i)]);
      }
    });
  }
Guolin Ke's avatar
Guolin Ke committed
113
114
}

Guolin Ke's avatar
Guolin Ke committed
115
116
117
118
119
void Tree::AddPredictionToScore(const Dataset* data,
  const data_size_t* used_data_indices,
  data_size_t num_data, double* score) const {
  if (data->num_features() > num_leaves_ - 1) {
    Threading::For<data_size_t>(0, num_data,
Guolin Ke's avatar
Guolin Ke committed
120
      [this, data, used_data_indices, score](int, data_size_t start, data_size_t end) {
Guolin Ke's avatar
Guolin Ke committed
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
      std::vector<std::unique_ptr<BinIterator>> iter(num_leaves_ - 1);
      for (int i = 0; i < num_leaves_ - 1; ++i) {
        const int fidx = split_feature_inner[i];
        iter[i].reset(data->FeatureIterator(fidx));
        iter[i]->Reset(used_data_indices[start]);
      }
      for (data_size_t i = start; i < end; ++i) {
        score[used_data_indices[i]] += static_cast<double>(leaf_value_[GetLeaf(iter, used_data_indices[i])]);
      }
    });
  } else {
    Threading::For<data_size_t>(0, num_data,
      [this, data, used_data_indices, score](int, data_size_t start, data_size_t end) {
      std::vector<std::unique_ptr<BinIterator>> iter(data->num_features());
      for (int i = 0; i < data->num_features(); ++i) {
        iter[i].reset(data->FeatureIterator(i));
        iter[i]->Reset(used_data_indices[start]);
      }
      for (data_size_t i = start; i < end; ++i) {
        score[used_data_indices[i]] += static_cast<double>(leaf_value_[GetLeafRaw(iter, used_data_indices[i])]);
      }
    });
  }
Guolin Ke's avatar
Guolin Ke committed
144
145
146
}

std::string Tree::ToString() {
147
148
149
  std::stringstream str_buf;
  str_buf << "num_leaves=" << num_leaves_ << std::endl;
  str_buf << "split_feature="
Guolin Ke's avatar
Guolin Ke committed
150
    << Common::ArrayToString<int>(split_feature_, num_leaves_ - 1, ' ') << std::endl;
151
  str_buf << "split_gain="
Guolin Ke's avatar
Guolin Ke committed
152
    << Common::ArrayToString<double>(split_gain_, num_leaves_ - 1, ' ') << std::endl;
153
  str_buf << "threshold="
Guolin Ke's avatar
Guolin Ke committed
154
    << Common::ArrayToString<double>(threshold_, num_leaves_ - 1, ' ') << std::endl;
155
  str_buf << "left_child="
Guolin Ke's avatar
Guolin Ke committed
156
    << Common::ArrayToString<int>(left_child_, num_leaves_ - 1, ' ') << std::endl;
157
  str_buf << "right_child="
Guolin Ke's avatar
Guolin Ke committed
158
    << Common::ArrayToString<int>(right_child_, num_leaves_ - 1, ' ') << std::endl;
159
  str_buf << "leaf_parent="
Guolin Ke's avatar
Guolin Ke committed
160
    << Common::ArrayToString<int>(leaf_parent_, num_leaves_, ' ') << std::endl;
161
  str_buf << "leaf_value="
Guolin Ke's avatar
Guolin Ke committed
162
    << Common::ArrayToString<double>(leaf_value_, num_leaves_, ' ') << std::endl;
163
  str_buf << "leaf_count="
Guolin Ke's avatar
Guolin Ke committed
164
    << Common::ArrayToString<data_size_t>(leaf_count_, num_leaves_, ' ') << std::endl;
165
  str_buf << "internal_value="
Guolin Ke's avatar
Guolin Ke committed
166
    << Common::ArrayToString<double>(internal_value_, num_leaves_ - 1, ' ') << std::endl;
167
  str_buf << "internal_count="
Guolin Ke's avatar
Guolin Ke committed
168
    << Common::ArrayToString<data_size_t>(internal_count_, num_leaves_ - 1, ' ') << std::endl;
Guolin Ke's avatar
Guolin Ke committed
169
  str_buf << "shrinkage=" << shrinkage_ << std::endl;
170
171
  str_buf << std::endl;
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
172
173
}

wxchan's avatar
wxchan committed
174
std::string Tree::ToJSON() {
175
  std::stringstream str_buf;
176
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
177
  str_buf << "\"num_leaves\":" << num_leaves_ << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
178
  str_buf << "\"shrinkage\":" << shrinkage_ << "," << std::endl;
179
  str_buf << "\"tree_structure\":" << NodeToJSON(0) << std::endl;
wxchan's avatar
wxchan committed
180

181
  return str_buf.str();
wxchan's avatar
wxchan committed
182
183
184
}

std::string Tree::NodeToJSON(int index) {
185
  std::stringstream str_buf;
186
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
wxchan's avatar
wxchan committed
187
188
  if (index >= 0) {
    // non-leaf
189
190
    str_buf << "{" << std::endl;
    str_buf << "\"split_index\":" << index << "," << std::endl;
Guolin Ke's avatar
Guolin Ke committed
191
    str_buf << "\"split_feature\":" << split_feature_[index] << "," << std::endl;
192
193
194
195
196
197
198
    str_buf << "\"split_gain\":" << split_gain_[index] << "," << std::endl;
    str_buf << "\"threshold\":" << threshold_[index] << "," << std::endl;
    str_buf << "\"internal_value\":" << internal_value_[index] << "," << std::endl;
    str_buf << "\"internal_count\":" << internal_count_[index] << "," << std::endl;
    str_buf << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << std::endl;
    str_buf << "\"right_child\":" << NodeToJSON(right_child_[index]) << std::endl;
    str_buf << "}";
wxchan's avatar
wxchan committed
199
200
201
  } else {
    // leaf
    index = ~index;
202
203
204
205
206
207
    str_buf << "{" << std::endl;
    str_buf << "\"leaf_index\":" << index << "," << std::endl;
    str_buf << "\"leaf_parent\":" << leaf_parent_[index] << "," << std::endl;
    str_buf << "\"leaf_value\":" << leaf_value_[index] << "," << std::endl;
    str_buf << "\"leaf_count\":" << leaf_count_[index] << std::endl;
    str_buf << "}";
wxchan's avatar
wxchan committed
208
209
  }

210
  return str_buf.str();
wxchan's avatar
wxchan committed
211
212
}

Guolin Ke's avatar
Guolin Ke committed
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
// Rebuilds a tree from text made of "key=value" lines, the layout written by
// ToString().
//
// Lines that do not split into exactly two parts on '=', or whose trimmed key
// or value is empty, are ignored. All of the following keys must be present,
// otherwise Log::Fatal is called: num_leaves, split_feature, split_gain,
// threshold, left_child, right_child, leaf_parent, leaf_value,
// internal_value, internal_count, leaf_count, shrinkage.
// Log::Fatal is also called when num_leaves does not parse to a value >= 1,
// because the array lengths below are derived from it.
//
// NOTE(review): this constructor does not assign max_leaves_, leaf_depth_,
// split_feature_inner or threshold_in_bin_. AddPredictionToScore reads
// split_feature_inner, so a tree loaded this way presumably must not be used
// on that path - confirm against callers.
//
// str: serialized tree text
Tree::Tree(const std::string& str) {
  // collect every well-formed key=value pair
  std::unordered_map<std::string, std::string> key_vals;
  for (const std::string& line : Common::Split(str.c_str(), '\n')) {
    std::vector<std::string> parts = Common::Split(line.c_str(), '=');
    if (parts.size() != 2) {
      continue;
    }
    std::string key = Common::Trim(parts[0]);
    std::string val = Common::Trim(parts[1]);
    if (!key.empty() && !val.empty()) {
      key_vals[key] = val;
    }
  }

  // every field written by ToString() is mandatory
  const char* const required_keys[] = {
    "num_leaves", "split_feature", "split_gain", "threshold",
    "left_child", "right_child", "leaf_parent", "leaf_value",
    "internal_value", "internal_count", "leaf_count", "shrinkage"
  };
  bool has_all_keys = true;
  for (const char* key : required_keys) {
    if (key_vals.count(key) == 0) {
      has_all_keys = false;
      break;
    }
  }
  if (!has_all_keys) {
    Log::Fatal("Tree model string format error");
  }

  // Start from a defined value so a failed parse cannot leave garbage behind,
  // then reject counts that would make num_leaves_ - 1 negative below.
  num_leaves_ = 0;
  Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);
  if (num_leaves_ < 1) {
    Log::Fatal("Tree model string format error: num_leaves must be at least 1");
  }

  // per-split-node arrays hold num_leaves_ - 1 entries
  const int num_split_nodes = num_leaves_ - 1;
  left_child_ = Common::StringToArray<int>(key_vals["left_child"], ' ', num_split_nodes);
  right_child_ = Common::StringToArray<int>(key_vals["right_child"], ' ', num_split_nodes);
  split_feature_ = Common::StringToArray<int>(key_vals["split_feature"], ' ', num_split_nodes);
  threshold_ = Common::StringToArray<double>(key_vals["threshold"], ' ', num_split_nodes);
  split_gain_ = Common::StringToArray<double>(key_vals["split_gain"], ' ', num_split_nodes);
  internal_count_ = Common::StringToArray<data_size_t>(key_vals["internal_count"], ' ', num_split_nodes);
  internal_value_ = Common::StringToArray<double>(key_vals["internal_value"], ' ', num_split_nodes);

  // per-leaf arrays hold num_leaves_ entries
  leaf_count_ = Common::StringToArray<data_size_t>(key_vals["leaf_count"], ' ', num_leaves_);
  leaf_parent_ = Common::StringToArray<int>(key_vals["leaf_parent"], ' ', num_leaves_);
  leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], ' ', num_leaves_);

  Common::Atof(key_vals["shrinkage"].c_str(), &shrinkage_);
}

}  // namespace LightGBM