serial_tree_learner.cpp 17.5 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
#include "serial_tree_learner.h"

#include <LightGBM/utils/array_args.h>

#include <algorithm>
#include <vector>

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
10
// Builds a learner bound to `tree_config` (the pointer is stored, not copied).
// Seeds the random generator from `feature_fraction_seed` and records how many
// threads an OpenMP parallel region runs with.
SerialTreeLearner::SerialTreeLearner(const TreeConfig* tree_config)
  : tree_config_(tree_config) {
  // generator seeded from the configured feature-fraction seed
  random_ = Random(tree_config_->feature_fraction_seed);

  // open a parallel region only to query its size; the master thread alone
  // writes the result, so there is a single writer
#pragma omp parallel
#pragma omp master
  {
    num_threads_ = omp_get_num_threads();
  }
}

// Destructor with an intentionally empty body: no explicit cleanup is performed here.
SerialTreeLearner::~SerialTreeLearner() {
}

void SerialTreeLearner::Init(const Dataset* train_data) {
  train_data_ = train_data;
  num_data_ = train_data_->num_data();
  num_features_ = train_data_->num_features();
27
28
  int max_cache_size = 0;
  // Get the max size of pool
Guolin Ke's avatar
Guolin Ke committed
29
30
  if (tree_config_->histogram_pool_size <= 0) {
    max_cache_size = tree_config_->num_leaves;
31
32
33
  } else {
    size_t total_histogram_size = 0;
    for (int i = 0; i < train_data_->num_features(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
34
      total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureNumBin(i);
35
    }
Guolin Ke's avatar
Guolin Ke committed
36
    max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
37
38
  }
  // at least need 2 leaves
Guolin Ke's avatar
Guolin Ke committed
39
  max_cache_size = std::max(2, max_cache_size);
Guolin Ke's avatar
Guolin Ke committed
40
  max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
41

Guolin Ke's avatar
Guolin Ke committed
42
  histogram_pool_.DynamicChangeSize(train_data_, tree_config_, max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
43
  // push split information for all leaves
Guolin Ke's avatar
Guolin Ke committed
44
  best_split_per_leaf_.resize(tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
45
  
Guolin Ke's avatar
Guolin Ke committed
46
  // get ordered bin
Guolin Ke's avatar
Guolin Ke committed
47
  train_data_->CreateOrderedBins(&ordered_bins_);
Guolin Ke's avatar
Guolin Ke committed
48
49

  // check existing for ordered bin
Guolin Ke's avatar
Guolin Ke committed
50
  for (int i = 0; i < static_cast<int>(ordered_bins_.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
51
52
53
54
55
    if (ordered_bins_[i] != nullptr) {
      has_ordered_bin_ = true;
      break;
    }
  }
wxchan's avatar
wxchan committed
56
  // initialize splits for leaf
Guolin Ke's avatar
Guolin Ke committed
57
58
  smaller_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
  larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
Guolin Ke's avatar
Guolin Ke committed
59
60

  // initialize data partition
Guolin Ke's avatar
Guolin Ke committed
61
  data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
Guolin Ke's avatar
Guolin Ke committed
62
  is_feature_used_.resize(num_features_);
Guolin Ke's avatar
Guolin Ke committed
63
  // initialize ordered gradients and hessians
Guolin Ke's avatar
Guolin Ke committed
64
65
66
  ordered_gradients_.resize(num_data_);
  ordered_hessians_.resize(num_data_);
  // if has ordered bin, need to allocate a buffer to fast split
Guolin Ke's avatar
Guolin Ke committed
67
  if (has_ordered_bin_) {
Guolin Ke's avatar
Guolin Ke committed
68
    is_data_in_leaf_.resize(num_data_);
Guolin Ke's avatar
Guolin Ke committed
69
    std::fill(is_data_in_leaf_.begin(), is_data_in_leaf_.end(), 0);
Guolin Ke's avatar
Guolin Ke committed
70
71
72
73
74
75
    order_bin_indices_.clear();
    for (int i = 0; i < static_cast<int>(ordered_bins_.size()); i++) {
      if (ordered_bins_[i] != nullptr) {
        order_bin_indices_.push_back(i);
      }
    }
Guolin Ke's avatar
Guolin Ke committed
76
  }
Guolin Ke's avatar
Guolin Ke committed
77
  Log::Info("Number of data: %d, number of used features: %d", num_data_, num_features_);
Guolin Ke's avatar
Guolin Ke committed
78
79
}

Guolin Ke's avatar
Guolin Ke committed
80
81
82
83
84
85
// Switches the learner to a different dataset without touching the tree
// configuration: refreshes the cached sizes, rebuilds the ordered bins and
// resizes every buffer whose length depends on the dataset.
//
// train_data: the new dataset; the pointer is stored as-is.
void SerialTreeLearner::ResetTrainingData(const Dataset* train_data) {
  train_data_ = train_data;
  num_data_ = train_data_->num_data();
  num_features_ = train_data_->num_features();

  // rebuild the ordered bins for the new data
  train_data_->CreateOrderedBins(&ordered_bins_);

  // does at least one feature own an ordered bin?
  has_ordered_bin_ = false;
  for (const auto& bin : ordered_bins_) {
    if (bin != nullptr) {
      has_ordered_bin_ = true;
      break;
    }
  }

  // leaf-split accumulators and the data partition follow the new row count
  smaller_leaf_splits_->ResetNumData(num_data_);
  larger_leaf_splits_->ResetNumData(num_data_);
  data_partition_->ResetNumData(num_data_);

  is_feature_used_.resize(num_features_);

  // ordered gradients and hessians hold one entry per row
  ordered_gradients_.resize(num_data_);
  ordered_hessians_.resize(num_data_);

  // everything below is only needed when ordered bins exist
  if (!has_ordered_bin_) {
    return;
  }

  // scratch buffer used to split the ordered bins quickly; starts all-zero
  is_data_in_leaf_.resize(num_data_);
  std::fill(is_data_in_leaf_.begin(), is_data_in_leaf_.end(), 0);

  // collect the indices of the features that have an ordered bin
  order_bin_indices_.clear();
  const int num_ordered_bins = static_cast<int>(ordered_bins_.size());
  for (int idx = 0; idx < num_ordered_bins; ++idx) {
    if (ordered_bins_[idx] != nullptr) {
      order_bin_indices_.push_back(idx);
    }
  }
}
Guolin Ke's avatar
Guolin Ke committed
121

Guolin Ke's avatar
Guolin Ke committed
122
123
124
125
126
127
128
129
130
131
void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
  if (tree_config_->num_leaves != tree_config->num_leaves) {
    tree_config_ = tree_config;
    int max_cache_size = 0;
    // Get the max size of pool
    if (tree_config->histogram_pool_size <= 0) {
      max_cache_size = tree_config_->num_leaves;
    } else {
      size_t total_histogram_size = 0;
      for (int i = 0; i < train_data_->num_features(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
132
        total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureNumBin(i);
Guolin Ke's avatar
Guolin Ke committed
133
134
135
136
137
138
      }
      max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
    }
    // at least need 2 leaves
    max_cache_size = std::max(2, max_cache_size);
    max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
139
    histogram_pool_.DynamicChangeSize(train_data_, tree_config_, max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
140
141
142

    // push split information for all leaves
    best_split_per_leaf_.resize(tree_config_->num_leaves);
143
    data_partition_->ResetLeaves(tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
144
145
146
147
  } else {
    tree_config_ = tree_config;
  }

Guolin Ke's avatar
Guolin Ke committed
148
  histogram_pool_.ResetConfig(tree_config_);
Guolin Ke's avatar
Guolin Ke committed
149
150
}

Guolin Ke's avatar
Guolin Ke committed
151
152
153
154
155
// Grows one tree from the given per-row gradients and hessians.
//
// Starting from a single root leaf, repeatedly picks the leaf whose best split
// has the largest gain and splits it, until either num_leaves - 1 splits were
// made or no leaf has a split with positive gain.
//
// gradients / hessians: stored as raw pointers for the duration of training.
// Returns a raw pointer released from a unique_ptr; the caller owns the tree.
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
  gradients_ = gradients;
  hessians_ = hessians;

  // per-tree preparation (feature sampling, partition reset, ...)
  BeforeTrain();

  std::unique_ptr<Tree> tree(new Tree(tree_config_->num_leaves));
  // BeforeFindBestSplit reads leaf depths through this pointer
  last_trained_tree_ = tree.get();

  // the first round only has the root: leaf 0, with no right sibling yet
  int left_leaf = 0;
  int right_leaf = -1;
  int cur_depth = 1;

  const int max_splits = tree_config_->num_leaves - 1;
  for (int split = 0; split < max_splits; ++split) {
    // refresh the candidates of the two newest leaves, unless they are
    // ruled out up front
    if (BeforeFindBestSplit(left_leaf, right_leaf)) {
      // find best threshold for every feature
      FindBestThresholds();
      // find best split from all features
      FindBestSplitsForLeaves();
    }

    // leaf whose best candidate has the maximal gain
    const int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
    const SplitInfo& candidate = best_split_per_leaf_[best_leaf];

    // no split improves the objective any more: stop growing
    if (candidate.gain <= 0.0) {
      Log::Info("No further splits with positive gain, best gain: %f", candidate.gain);
      break;
    }

    // apply the split; left_leaf/right_leaf now name the two children
    Split(tree.get(), best_leaf, &left_leaf, &right_leaf);
    cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf));
  }

  Log::Info("Trained a tree with leaves=%d and max_depth=%d", tree->num_leaves(), cur_depth);
  return tree.release();
}

void SerialTreeLearner::BeforeTrain() {
Guolin Ke's avatar
Guolin Ke committed
190

191
192
  // reset histogram pool
  histogram_pool_.ResetMap();
Guolin Ke's avatar
Guolin Ke committed
193
  int used_feature_cnt = static_cast<int>(num_features_*tree_config_->feature_fraction);
Guolin Ke's avatar
Guolin Ke committed
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208

  if (used_feature_cnt < num_features_) {
    // initialize used features
    std::memset(is_feature_used_.data(), 0, sizeof(int8_t) * num_features_);
    // Get used feature at current tree
    auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt);
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < static_cast<int>(used_feature_indices.size()); ++i) {
      is_feature_used_[used_feature_indices[i]] = 1;
    }
  } else {
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < num_features_; ++i) {
      is_feature_used_[i] = 1;
    }
Guolin Ke's avatar
Guolin Ke committed
209
  }
210

Guolin Ke's avatar
Guolin Ke committed
211
212
213
214
  // initialize data partition
  data_partition_->Init();

  // reset the splits for leaves
Guolin Ke's avatar
Guolin Ke committed
215
  for (int i = 0; i < tree_config_->num_leaves; ++i) {
Guolin Ke's avatar
Guolin Ke committed
216
217
218
219
220
221
222
    best_split_per_leaf_[i].Reset();
  }

  // Sumup for root
  if (data_partition_->leaf_count(0) == num_data_) {
    // use all data
    smaller_leaf_splits_->Init(gradients_, hessians_);
Guolin Ke's avatar
Guolin Ke committed
223

Guolin Ke's avatar
Guolin Ke committed
224
225
  } else {
    // use bagging, only use part of data
Guolin Ke's avatar
Guolin Ke committed
226
    smaller_leaf_splits_->Init(0, data_partition_.get(), gradients_, hessians_);
Guolin Ke's avatar
Guolin Ke committed
227
228
229
230
231
232
233
234
  }

  larger_leaf_splits_->Init();

  // if has ordered bin, need to initialize the ordered bin
  if (has_ordered_bin_) {
    if (data_partition_->leaf_count(0) == num_data_) {
      // use all data, pass nullptr
Guolin Ke's avatar
Guolin Ke committed
235
236
237
      #pragma omp parallel for schedule(static)
      for (int i = 0; i < static_cast<int>(order_bin_indices_.size()); ++i) {
        ordered_bins_[order_bin_indices_[i]]->Init(nullptr, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
238
239
240
241
242
243
244
245
246
247
248
249
250
      }
    } else {
      // bagging, only use part of data

      // mark used data
      const data_size_t* indices = data_partition_->indices();
      data_size_t begin = data_partition_->leaf_begin(0);
      data_size_t end = begin + data_partition_->leaf_count(0);
      #pragma omp parallel for schedule(static)
      for (data_size_t i = begin; i < end; ++i) {
        is_data_in_leaf_[indices[i]] = 1;
      }
      // initialize ordered bin
Guolin Ke's avatar
Guolin Ke committed
251
252
253
      #pragma omp parallel for schedule(static)
      for (int i = 0; i < static_cast<int>(order_bin_indices_.size()); ++i) {
        ordered_bins_[order_bin_indices_[i]]->Init(is_data_in_leaf_.data(), tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
254
      }
Guolin Ke's avatar
Guolin Ke committed
255
256
257
258
#pragma omp parallel for schedule(static)
      for (data_size_t i = begin; i < end; ++i) {
        is_data_in_leaf_[indices[i]] = 0;
      }
Guolin Ke's avatar
Guolin Ke committed
259
260
261
262
263
    }
  }
}

bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
Guolin Ke's avatar
Guolin Ke committed
264
  // check depth of current leaf
Guolin Ke's avatar
Guolin Ke committed
265
  if (tree_config_->max_depth > 0) {
Guolin Ke's avatar
Guolin Ke committed
266
    // only need to check left leaf, since right leaf is in same level of left leaf
Guolin Ke's avatar
Guolin Ke committed
267
    if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_->max_depth) {
Guolin Ke's avatar
Guolin Ke committed
268
269
270
271
272
273
274
      best_split_per_leaf_[left_leaf].gain = kMinScore;
      if (right_leaf >= 0) {
        best_split_per_leaf_[right_leaf].gain = kMinScore;
      }
      return false;
    }
  }
Guolin Ke's avatar
Guolin Ke committed
275
276
277
  data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
  data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
  // no enough data to continue
Guolin Ke's avatar
Guolin Ke committed
278
279
  if (num_data_in_right_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)
    && num_data_in_left_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)) {
Guolin Ke's avatar
Guolin Ke committed
280
281
282
283
284
285
    best_split_per_leaf_[left_leaf].gain = kMinScore;
    if (right_leaf >= 0) {
      best_split_per_leaf_[right_leaf].gain = kMinScore;
    }
    return false;
  }
286
  parent_leaf_histogram_array_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
287
288
  // only have root
  if (right_leaf < 0) {
289
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
290
291
    larger_leaf_histogram_array_ = nullptr;
  } else if (num_data_in_left_child < num_data_in_right_child) {
Hui Xue's avatar
Hui Xue committed
292
    // put parent(left) leaf's histograms into larger leaf's histograms
293
294
295
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Move(left_leaf, right_leaf);
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
296
  } else {
Hui Xue's avatar
Hui Xue committed
297
    // put parent(left) leaf's histograms to larger leaf's histograms
298
299
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Get(right_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
300
301
302
303
304
  }
  // split for the ordered bin
  if (has_ordered_bin_ && right_leaf >= 0) {
    // mark data that at left-leaf
    const data_size_t* indices = data_partition_->indices();
Guolin Ke's avatar
Guolin Ke committed
305
306
307
    const auto left_cnt = data_partition_->leaf_count(left_leaf);
    const auto right_cnt = data_partition_->leaf_count(right_leaf);
    char mark = 1;
Guolin Ke's avatar
Guolin Ke committed
308
    data_size_t begin = data_partition_->leaf_begin(left_leaf);
Guolin Ke's avatar
Guolin Ke committed
309
310
311
312
313
314
    data_size_t end = begin + left_cnt;
    if (left_cnt > right_cnt) {
      begin = data_partition_->leaf_begin(right_leaf);
      end = begin + right_cnt;
      mark = 0;
    }
Guolin Ke's avatar
Guolin Ke committed
315
316
317
318
319
320
    #pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      is_data_in_leaf_[indices[i]] = 1;
    }
    // split the ordered bin
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
321
322
    for (int i = 0; i < static_cast<int>(order_bin_indices_.size()); ++i) {
      ordered_bins_[order_bin_indices_[i]]->Split(left_leaf, right_leaf, is_data_in_leaf_.data(), mark);
Guolin Ke's avatar
Guolin Ke committed
323
    }
Guolin Ke's avatar
Guolin Ke committed
324
325
326
327
#pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      is_data_in_leaf_[indices[i]] = 0;
    }
Guolin Ke's avatar
Guolin Ke committed
328
329
330
331
332
  }
  return true;
}

void SerialTreeLearner::FindBestThresholds() {
Guolin Ke's avatar
Guolin Ke committed
333
  std::vector<int8_t> is_feature_used(num_features_, 0);
Guolin Ke's avatar
Guolin Ke committed
334
#pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
335
336
337
338
  for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
    if (!is_feature_used_[feature_index]) continue;
    if (parent_leaf_histogram_array_ != nullptr 
        && !parent_leaf_histogram_array_[feature_index].is_splittable()) {
Guolin Ke's avatar
Guolin Ke committed
339
340
341
      smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
      continue;
    }
Guolin Ke's avatar
Guolin Ke committed
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
    is_feature_used[feature_index] = 1;
  }
  bool use_subtract = true;
  if (parent_leaf_histogram_array_ == nullptr) {
    use_subtract = false;
  }
  // construct smaller leaf
  HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1;
  train_data_->ConstructHistograms(is_feature_used,
    smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
    smaller_leaf_splits_->LeafIndex(),
    ordered_bins_, gradients_, hessians_,
    ordered_gradients_.data(), ordered_hessians_.data(),
    ptr_smaller_leaf_hist_data);

  if (larger_leaf_histogram_array_ != nullptr && !use_subtract) {
    // construct larger leaf
    HistogramBinEntry* ptr_larger_leaf_hist_data = larger_leaf_histogram_array_[0].RawData() - 1;
    train_data_->ConstructHistograms(is_feature_used,
      larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf(),
      larger_leaf_splits_->LeafIndex(),
      ordered_bins_, gradients_, hessians_,
      ordered_gradients_.data(), ordered_hessians_.data(),
      ptr_larger_leaf_hist_data);
  }
  std::vector<SplitInfo> smaller_best(num_threads_);
  std::vector<SplitInfo> larger_best(num_threads_);
  // find splits
Guolin Ke's avatar
Guolin Ke committed
370
  #pragma omp parallel for schedule(static)
Guolin Ke's avatar
Guolin Ke committed
371
372
373
374
375
376
377
378
  for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
    if (!is_feature_used[feature_index]) { continue; }
    const int tid = omp_get_thread_num();
    SplitInfo smaller_split;
    train_data_->FixHistogram(feature_index, 
      smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(),
      smaller_leaf_splits_->num_data_in_leaf(),
      smaller_leaf_histogram_array_[feature_index].RawData());
Guolin Ke's avatar
Guolin Ke committed
379

Guolin Ke's avatar
Guolin Ke committed
380
381
382
383
    smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
      smaller_leaf_splits_->sum_gradients(),
      smaller_leaf_splits_->sum_hessians(),
      smaller_leaf_splits_->num_data_in_leaf(),
Guolin Ke's avatar
Guolin Ke committed
384
385
386
387
      &smaller_split);
    if (smaller_split.gain > smaller_best[tid].gain) {
      smaller_best[tid] = smaller_split;
    }
Guolin Ke's avatar
Guolin Ke committed
388
    // only has root leaf
Guolin Ke's avatar
Guolin Ke committed
389
    if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
390

Guolin Ke's avatar
Guolin Ke committed
391
    if (use_subtract) {
392
393
      larger_leaf_histogram_array_[feature_index].Subtract(smaller_leaf_histogram_array_[feature_index]);
    } else {
Guolin Ke's avatar
Guolin Ke committed
394
395
396
      train_data_->FixHistogram(feature_index, larger_leaf_splits_->sum_gradients(), larger_leaf_splits_->sum_hessians(),
        larger_leaf_splits_->num_data_in_leaf(),
        larger_leaf_histogram_array_[feature_index].RawData());
397
    }
Guolin Ke's avatar
Guolin Ke committed
398
    SplitInfo larger_split;
Guolin Ke's avatar
Guolin Ke committed
399
    // find best threshold for larger child
Guolin Ke's avatar
Guolin Ke committed
400
401
402
403
    larger_leaf_histogram_array_[feature_index].FindBestThreshold(
      larger_leaf_splits_->sum_gradients(),
      larger_leaf_splits_->sum_hessians(),
      larger_leaf_splits_->num_data_in_leaf(),
Guolin Ke's avatar
Guolin Ke committed
404
405
406
407
      &larger_split);
    if (larger_split.gain > larger_best[tid].gain) {
      larger_best[tid] = larger_split;
    }
Guolin Ke's avatar
Guolin Ke committed
408
  }
Guolin Ke's avatar
Guolin Ke committed
409
410
411
412
413
414
415
416
417
418
419
420
421
422

  auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_best);
  int leaf = smaller_leaf_splits_->LeafIndex();
  best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx];

  if (larger_leaf_splits_ != nullptr && larger_leaf_splits_->LeafIndex() >= 0) {
    leaf = larger_leaf_splits_->LeafIndex();
    auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best);
    best_split_per_leaf_[leaf] = larger_best[larger_best_idx];
  }
}

// Intentionally empty in this learner: FindBestThresholds() already writes the
// best split of each processed leaf into best_split_per_leaf_, so nothing is
// left to do in this step.
void SerialTreeLearner::FindBestSplitsForLeaves() {
}


// Applies the best recorded split of leaf `best_Leaf`.
//
// tree:       tree being grown; receives the new internal node.
// best_Leaf:  index of the leaf to split.
// left_leaf:  out; set to best_Leaf (the left child keeps the parent's index).
// right_leaf: out; set to the index returned by Tree::Split for the new leaf.
//
// The data partition is split accordingly, and the smaller/larger leaf-split
// accumulators are re-initialized for the next iteration.
void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
  const SplitInfo& info = best_split_per_leaf_[best_Leaf];

  // left = parent
  *left_leaf = best_Leaf;

  // split tree, will return right leaf; both the inner (binned) and the real
  // feature index / threshold are passed along
  *right_leaf = tree->Split(
    best_Leaf,
    info.feature,
    info.threshold,
    train_data_->RealFeatureIndex(info.feature),
    train_data_->RealThreshold(info.feature, info.threshold),
    static_cast<double>(info.left_output),
    static_cast<double>(info.right_output),
    static_cast<data_size_t>(info.left_count),
    static_cast<data_size_t>(info.right_count),
    static_cast<double>(info.gain));

  // split data partition
  data_partition_->Split(best_Leaf, train_data_, info.feature, info.threshold, *right_leaf);

  // init the leaves that used on next iteration: the child with strictly fewer
  // rows becomes the "smaller" leaf; on a tie the right child does
  const bool left_is_smaller = info.left_count < info.right_count;

  const int smaller_leaf = left_is_smaller ? *left_leaf : *right_leaf;
  const auto smaller_sum_gradient = left_is_smaller ? info.left_sum_gradient : info.right_sum_gradient;
  const auto smaller_sum_hessian = left_is_smaller ? info.left_sum_hessian : info.right_sum_hessian;

  const int larger_leaf = left_is_smaller ? *right_leaf : *left_leaf;
  const auto larger_sum_gradient = left_is_smaller ? info.right_sum_gradient : info.left_sum_gradient;
  const auto larger_sum_hessian = left_is_smaller ? info.right_sum_hessian : info.left_sum_hessian;

  smaller_leaf_splits_->Init(smaller_leaf, data_partition_.get(),
                             smaller_sum_gradient, smaller_sum_hessian);
  larger_leaf_splits_->Init(larger_leaf, data_partition_.get(),
                            larger_sum_gradient, larger_sum_hessian);
}

Guolin Ke's avatar
Guolin Ke committed
458

Guolin Ke's avatar
Guolin Ke committed
459
}  // namespace LightGBM