// serial_tree_learner.cpp: SerialTreeLearner implementation.
#include "serial_tree_learner.h"

#include <LightGBM/utils/array_args.h>

#include <algorithm>
#include <vector>

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
10
SerialTreeLearner::SerialTreeLearner(const TreeConfig* tree_config)
Guolin Ke's avatar
Guolin Ke committed
11
  :tree_config_(tree_config){
Guolin Ke's avatar
Guolin Ke committed
12
  random_ = Random(tree_config_->feature_fraction_seed);
Guolin Ke's avatar
Guolin Ke committed
13
14
15
16
17
#pragma omp parallel
#pragma omp master
  {
    num_threads_ = omp_get_num_threads();
  }
Guolin Ke's avatar
Guolin Ke committed
18
19
20
21
22
23
24
25
26
}

// Nothing to release by hand: the members release their own resources.
SerialTreeLearner::~SerialTreeLearner() = default;

void SerialTreeLearner::Init(const Dataset* train_data) {
  train_data_ = train_data;
  num_data_ = train_data_->num_data();
  num_features_ = train_data_->num_features();
27
28
  int max_cache_size = 0;
  // Get the max size of pool
Guolin Ke's avatar
Guolin Ke committed
29
30
  if (tree_config_->histogram_pool_size <= 0) {
    max_cache_size = tree_config_->num_leaves;
31
32
33
  } else {
    size_t total_histogram_size = 0;
    for (int i = 0; i < train_data_->num_features(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
34
      total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureNumBin(i);
35
    }
Guolin Ke's avatar
Guolin Ke committed
36
    max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
37
38
  }
  // at least need 2 leaves
Guolin Ke's avatar
Guolin Ke committed
39
  max_cache_size = std::max(2, max_cache_size);
Guolin Ke's avatar
Guolin Ke committed
40
  max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
41

Guolin Ke's avatar
Guolin Ke committed
42
  histogram_pool_.DynamicChangeSize(train_data_, tree_config_, max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
43
  // push split information for all leaves
Guolin Ke's avatar
Guolin Ke committed
44
  best_split_per_leaf_.resize(tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
45
  
Guolin Ke's avatar
Guolin Ke committed
46
  // get ordered bin
Guolin Ke's avatar
Guolin Ke committed
47
  train_data_->CreateOrderedBins(&ordered_bins_);
Guolin Ke's avatar
Guolin Ke committed
48
49

  // check existing for ordered bin
Guolin Ke's avatar
Guolin Ke committed
50
  for (int i = 0; i < static_cast<int>(ordered_bins_.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
51
52
53
54
55
    if (ordered_bins_[i] != nullptr) {
      has_ordered_bin_ = true;
      break;
    }
  }
wxchan's avatar
wxchan committed
56
  // initialize splits for leaf
Guolin Ke's avatar
Guolin Ke committed
57
58
  smaller_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
  larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
Guolin Ke's avatar
Guolin Ke committed
59
60

  // initialize data partition
Guolin Ke's avatar
Guolin Ke committed
61
  data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
Guolin Ke's avatar
Guolin Ke committed
62
  is_feature_used_.resize(num_features_);
Guolin Ke's avatar
Guolin Ke committed
63
  // initialize ordered gradients and hessians
Guolin Ke's avatar
Guolin Ke committed
64
65
66
  ordered_gradients_.resize(num_data_);
  ordered_hessians_.resize(num_data_);
  // if has ordered bin, need to allocate a buffer to fast split
Guolin Ke's avatar
Guolin Ke committed
67
  if (has_ordered_bin_) {
Guolin Ke's avatar
Guolin Ke committed
68
    is_data_in_leaf_.resize(num_data_);
Guolin Ke's avatar
Guolin Ke committed
69
    std::fill(is_data_in_leaf_.begin(), is_data_in_leaf_.end(), 0);
Guolin Ke's avatar
Guolin Ke committed
70
  }
Guolin Ke's avatar
Guolin Ke committed
71
  Log::Info("Number of data: %d, number of used features: %d", num_data_, num_features_);
Guolin Ke's avatar
Guolin Ke committed
72
73
}

Guolin Ke's avatar
Guolin Ke committed
74
75
76
77
78
79
void SerialTreeLearner::ResetTrainingData(const Dataset* train_data) {
  train_data_ = train_data;
  num_data_ = train_data_->num_data();
  num_features_ = train_data_->num_features();

  // get ordered bin
Guolin Ke's avatar
Guolin Ke committed
80
81
  train_data_->CreateOrderedBins(&ordered_bins_);

Guolin Ke's avatar
Guolin Ke committed
82
83
  has_ordered_bin_ = false;
  // check existing for ordered bin
Guolin Ke's avatar
Guolin Ke committed
84
  for (int i = 0; i < static_cast<int>(ordered_bins_.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
    if (ordered_bins_[i] != nullptr) {
      has_ordered_bin_ = true;
      break;
    }
  }
  // initialize splits for leaf
  smaller_leaf_splits_->ResetNumData(num_data_);
  larger_leaf_splits_->ResetNumData(num_data_);

  // initialize data partition
  data_partition_->ResetNumData(num_data_);

  is_feature_used_.resize(num_features_);

  // initialize ordered gradients and hessians
  ordered_gradients_.resize(num_data_);
  ordered_hessians_.resize(num_data_);
  // if has ordered bin, need to allocate a buffer to fast split
  if (has_ordered_bin_) {
    is_data_in_leaf_.resize(num_data_);
Guolin Ke's avatar
Guolin Ke committed
105
    std::fill(is_data_in_leaf_.begin(), is_data_in_leaf_.end(), 0);
Guolin Ke's avatar
Guolin Ke committed
106
107
108
  }

}
Guolin Ke's avatar
Guolin Ke committed
109

Guolin Ke's avatar
Guolin Ke committed
110
111
112
113
114
115
116
117
118
119
void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
  if (tree_config_->num_leaves != tree_config->num_leaves) {
    tree_config_ = tree_config;
    int max_cache_size = 0;
    // Get the max size of pool
    if (tree_config->histogram_pool_size <= 0) {
      max_cache_size = tree_config_->num_leaves;
    } else {
      size_t total_histogram_size = 0;
      for (int i = 0; i < train_data_->num_features(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
120
        total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureNumBin(i);
Guolin Ke's avatar
Guolin Ke committed
121
122
123
124
125
126
      }
      max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
    }
    // at least need 2 leaves
    max_cache_size = std::max(2, max_cache_size);
    max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
127
    histogram_pool_.DynamicChangeSize(train_data_, tree_config_, max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
128
129
130

    // push split information for all leaves
    best_split_per_leaf_.resize(tree_config_->num_leaves);
131
    data_partition_->ResetLeaves(tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
  } else {
    tree_config_ = tree_config;
  }

Guolin Ke's avatar
Guolin Ke committed
136
  histogram_pool_.ResetConfig(tree_config_);
Guolin Ke's avatar
Guolin Ke committed
137
138
}

Guolin Ke's avatar
Guolin Ke committed
139
140
141
142
143
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
  gradients_ = gradients;
  hessians_ = hessians;
  // some initial works before training
  BeforeTrain();
Guolin Ke's avatar
Guolin Ke committed
144
  auto tree = std::unique_ptr<Tree>(new Tree(tree_config_->num_leaves));
Guolin Ke's avatar
Guolin Ke committed
145
  // save pointer to last trained tree
Guolin Ke's avatar
Guolin Ke committed
146
  last_trained_tree_ = tree.get();
Guolin Ke's avatar
Guolin Ke committed
147
148
  // root leaf
  int left_leaf = 0;
149
  int cur_depth = 1;
Guolin Ke's avatar
Guolin Ke committed
150
151
  // only root leaf can be splitted on first time
  int right_leaf = -1;
Guolin Ke's avatar
Guolin Ke committed
152
  for (int split = 0; split < tree_config_->num_leaves - 1; ++split) {
Guolin Ke's avatar
Guolin Ke committed
153
154
155
156
157
158
159
160
161
162
163
164
165
    // some initial works before finding best split
    if (BeforeFindBestSplit(left_leaf, right_leaf)) {
      // find best threshold for every feature
      FindBestThresholds();
      // find best split from all features
      FindBestSplitsForLeaves();
    }
    // Get a leaf with max split gain
    int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
    // Get split information for best leaf
    const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf];
    // cannot split, quit
    if (best_leaf_SplitInfo.gain <= 0.0) {
166
      Log::Info("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain);
Guolin Ke's avatar
Guolin Ke committed
167
168
169
      break;
    }
    // split tree with best leaf
Guolin Ke's avatar
Guolin Ke committed
170
    Split(tree.get(), best_leaf, &left_leaf, &right_leaf);
171
    cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf));
Guolin Ke's avatar
Guolin Ke committed
172
  }
173
  Log::Info("Trained a tree with leaves=%d and max_depth=%d", tree->num_leaves(), cur_depth);
Guolin Ke's avatar
Guolin Ke committed
174
  return tree.release();
Guolin Ke's avatar
Guolin Ke committed
175
176
177
}

void SerialTreeLearner::BeforeTrain() {
Guolin Ke's avatar
Guolin Ke committed
178

179
180
  // reset histogram pool
  histogram_pool_.ResetMap();
Guolin Ke's avatar
Guolin Ke committed
181
  int used_feature_cnt = static_cast<int>(num_features_*tree_config_->feature_fraction);
Guolin Ke's avatar
Guolin Ke committed
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196

  if (used_feature_cnt < num_features_) {
    // initialize used features
    std::memset(is_feature_used_.data(), 0, sizeof(int8_t) * num_features_);
    // Get used feature at current tree
    auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt);
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < static_cast<int>(used_feature_indices.size()); ++i) {
      is_feature_used_[used_feature_indices[i]] = 1;
    }
  } else {
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < num_features_; ++i) {
      is_feature_used_[i] = 1;
    }
Guolin Ke's avatar
Guolin Ke committed
197
  }
198

Guolin Ke's avatar
Guolin Ke committed
199
200
201
202
  // initialize data partition
  data_partition_->Init();

  // reset the splits for leaves
Guolin Ke's avatar
Guolin Ke committed
203
  for (int i = 0; i < tree_config_->num_leaves; ++i) {
Guolin Ke's avatar
Guolin Ke committed
204
205
206
207
208
209
210
    best_split_per_leaf_[i].Reset();
  }

  // Sumup for root
  if (data_partition_->leaf_count(0) == num_data_) {
    // use all data
    smaller_leaf_splits_->Init(gradients_, hessians_);
Guolin Ke's avatar
Guolin Ke committed
211

Guolin Ke's avatar
Guolin Ke committed
212
213
  } else {
    // use bagging, only use part of data
Guolin Ke's avatar
Guolin Ke committed
214
    smaller_leaf_splits_->Init(0, data_partition_.get(), gradients_, hessians_);
Guolin Ke's avatar
Guolin Ke committed
215
216
217
218
219
220
221
222
223
  }

  larger_leaf_splits_->Init();

  // if has ordered bin, need to initialize the ordered bin
  if (has_ordered_bin_) {
    if (data_partition_->leaf_count(0) == num_data_) {
      // use all data, pass nullptr
      #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
224
225
226
227
      for (int i = 0; i < static_cast<int>(ordered_bins_.size()); ++i) {
        auto ptr = ordered_bins_[i].get();
        if (ptr != nullptr) {
          ptr->Init(nullptr, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
        }
      }
    } else {
      // bagging, only use part of data

      // mark used data
      const data_size_t* indices = data_partition_->indices();
      data_size_t begin = data_partition_->leaf_begin(0);
      data_size_t end = begin + data_partition_->leaf_count(0);
      #pragma omp parallel for schedule(static)
      for (data_size_t i = begin; i < end; ++i) {
        is_data_in_leaf_[indices[i]] = 1;
      }
      // initialize ordered bin
      #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
243
244
245
246
      for (int i = 0; i < static_cast<int>(ordered_bins_.size()); ++i) {
        auto ptr = ordered_bins_[i].get();
        if (ptr != nullptr) {
          ptr->Init(is_data_in_leaf_.data(), tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
247
248
        }
      }
Guolin Ke's avatar
Guolin Ke committed
249
250
251
252
#pragma omp parallel for schedule(static)
      for (data_size_t i = begin; i < end; ++i) {
        is_data_in_leaf_[indices[i]] = 0;
      }
Guolin Ke's avatar
Guolin Ke committed
253
254
255
256
257
    }
  }
}

bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
Guolin Ke's avatar
Guolin Ke committed
258
  // check depth of current leaf
Guolin Ke's avatar
Guolin Ke committed
259
  if (tree_config_->max_depth > 0) {
Guolin Ke's avatar
Guolin Ke committed
260
    // only need to check left leaf, since right leaf is in same level of left leaf
Guolin Ke's avatar
Guolin Ke committed
261
    if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_->max_depth) {
Guolin Ke's avatar
Guolin Ke committed
262
263
264
265
266
267
268
      best_split_per_leaf_[left_leaf].gain = kMinScore;
      if (right_leaf >= 0) {
        best_split_per_leaf_[right_leaf].gain = kMinScore;
      }
      return false;
    }
  }
Guolin Ke's avatar
Guolin Ke committed
269
270
271
  data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
  data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
  // no enough data to continue
Guolin Ke's avatar
Guolin Ke committed
272
273
  if (num_data_in_right_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)
    && num_data_in_left_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)) {
Guolin Ke's avatar
Guolin Ke committed
274
275
276
277
278
279
    best_split_per_leaf_[left_leaf].gain = kMinScore;
    if (right_leaf >= 0) {
      best_split_per_leaf_[right_leaf].gain = kMinScore;
    }
    return false;
  }
280
  parent_leaf_histogram_array_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
281
282
  // only have root
  if (right_leaf < 0) {
283
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
284
285
    larger_leaf_histogram_array_ = nullptr;
  } else if (num_data_in_left_child < num_data_in_right_child) {
Hui Xue's avatar
Hui Xue committed
286
    // put parent(left) leaf's histograms into larger leaf's histograms
287
288
289
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Move(left_leaf, right_leaf);
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
290
  } else {
Hui Xue's avatar
Hui Xue committed
291
    // put parent(left) leaf's histograms to larger leaf's histograms
292
293
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Get(right_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
294
295
296
297
298
  }
  // split for the ordered bin
  if (has_ordered_bin_ && right_leaf >= 0) {
    // mark data that at left-leaf
    const data_size_t* indices = data_partition_->indices();
Guolin Ke's avatar
Guolin Ke committed
299
300
301
    const auto left_cnt = data_partition_->leaf_count(left_leaf);
    const auto right_cnt = data_partition_->leaf_count(right_leaf);
    char mark = 1;
Guolin Ke's avatar
Guolin Ke committed
302
    data_size_t begin = data_partition_->leaf_begin(left_leaf);
Guolin Ke's avatar
Guolin Ke committed
303
304
305
306
307
308
    data_size_t end = begin + left_cnt;
    if (left_cnt > right_cnt) {
      begin = data_partition_->leaf_begin(right_leaf);
      end = begin + right_cnt;
      mark = 0;
    }
Guolin Ke's avatar
Guolin Ke committed
309
310
311
312
313
314
    #pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      is_data_in_leaf_[indices[i]] = 1;
    }
    // split the ordered bin
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
315
316
317
318
    for (int i = 0; i < static_cast<int>(ordered_bins_.size()); ++i) {
      auto ptr = ordered_bins_[i].get();
      if (ptr != nullptr) {
        ptr->Split(left_leaf, right_leaf, is_data_in_leaf_.data(), mark);
Guolin Ke's avatar
Guolin Ke committed
319
320
      }
    }
Guolin Ke's avatar
Guolin Ke committed
321
322
323
324
#pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      is_data_in_leaf_[indices[i]] = 0;
    }
Guolin Ke's avatar
Guolin Ke committed
325
326
327
328
329
  }
  return true;
}

void SerialTreeLearner::FindBestThresholds() {
Guolin Ke's avatar
Guolin Ke committed
330
331
332
333
334
335
  std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(guided)
  for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
    if (!is_feature_used_[feature_index]) continue;
    if (parent_leaf_histogram_array_ != nullptr 
        && !parent_leaf_histogram_array_[feature_index].is_splittable()) {
Guolin Ke's avatar
Guolin Ke committed
336
337
338
      smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
      continue;
    }
Guolin Ke's avatar
Guolin Ke committed
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
    is_feature_used[feature_index] = 1;
  }
  bool use_subtract = true;
  if (parent_leaf_histogram_array_ == nullptr) {
    use_subtract = false;
  }
  // construct smaller leaf
  HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1;
  train_data_->ConstructHistograms(is_feature_used,
    smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
    smaller_leaf_splits_->LeafIndex(),
    ordered_bins_, gradients_, hessians_,
    ordered_gradients_.data(), ordered_hessians_.data(),
    ptr_smaller_leaf_hist_data);

  if (larger_leaf_histogram_array_ != nullptr && !use_subtract) {
    // construct larger leaf
    HistogramBinEntry* ptr_larger_leaf_hist_data = larger_leaf_histogram_array_[0].RawData() - 1;
    train_data_->ConstructHistograms(is_feature_used,
      larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf(),
      larger_leaf_splits_->LeafIndex(),
      ordered_bins_, gradients_, hessians_,
      ordered_gradients_.data(), ordered_hessians_.data(),
      ptr_larger_leaf_hist_data);
  }
  std::vector<SplitInfo> smaller_best(num_threads_);
  std::vector<SplitInfo> larger_best(num_threads_);
  // find splits
  #pragma omp parallel for schedule(guided)
  for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
    if (!is_feature_used[feature_index]) { continue; }
    const int tid = omp_get_thread_num();
    SplitInfo smaller_split;
    train_data_->FixHistogram(feature_index, 
      smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(),
      smaller_leaf_splits_->num_data_in_leaf(),
      smaller_leaf_histogram_array_[feature_index].RawData());
Guolin Ke's avatar
Guolin Ke committed
376

Guolin Ke's avatar
Guolin Ke committed
377
378
379
380
    smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
      smaller_leaf_splits_->sum_gradients(),
      smaller_leaf_splits_->sum_hessians(),
      smaller_leaf_splits_->num_data_in_leaf(),
Guolin Ke's avatar
Guolin Ke committed
381
382
383
384
      &smaller_split);
    if (smaller_split.gain > smaller_best[tid].gain) {
      smaller_best[tid] = smaller_split;
    }
Guolin Ke's avatar
Guolin Ke committed
385
    // only has root leaf
Guolin Ke's avatar
Guolin Ke committed
386
    if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
387

Guolin Ke's avatar
Guolin Ke committed
388
    if (use_subtract) {
389
390
      larger_leaf_histogram_array_[feature_index].Subtract(smaller_leaf_histogram_array_[feature_index]);
    } else {
Guolin Ke's avatar
Guolin Ke committed
391
392
393
      train_data_->FixHistogram(feature_index, larger_leaf_splits_->sum_gradients(), larger_leaf_splits_->sum_hessians(),
        larger_leaf_splits_->num_data_in_leaf(),
        larger_leaf_histogram_array_[feature_index].RawData());
394
    }
Guolin Ke's avatar
Guolin Ke committed
395
    SplitInfo larger_split;
Guolin Ke's avatar
Guolin Ke committed
396
    // find best threshold for larger child
Guolin Ke's avatar
Guolin Ke committed
397
398
399
400
    larger_leaf_histogram_array_[feature_index].FindBestThreshold(
      larger_leaf_splits_->sum_gradients(),
      larger_leaf_splits_->sum_hessians(),
      larger_leaf_splits_->num_data_in_leaf(),
Guolin Ke's avatar
Guolin Ke committed
401
402
403
404
      &larger_split);
    if (larger_split.gain > larger_best[tid].gain) {
      larger_best[tid] = larger_split;
    }
Guolin Ke's avatar
Guolin Ke committed
405
  }
Guolin Ke's avatar
Guolin Ke committed
406
407
408
409
410
411
412
413
414
415
416
417
418
419

  auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_best);
  int leaf = smaller_leaf_splits_->LeafIndex();
  best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx];

  if (larger_leaf_splits_ != nullptr && larger_leaf_splits_->LeafIndex() >= 0) {
    leaf = larger_leaf_splits_->LeafIndex();
    auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best);
    best_split_per_leaf_[leaf] = larger_best[larger_best_idx];
  }
}

// Hook called by Train() right after FindBestThresholds(). It is empty here,
// because FindBestThresholds() already writes each leaf's best split into
// best_split_per_leaf_. NOTE(review): presumably kept as an override point for
// other learners -- confirm in the header.
void SerialTreeLearner::FindBestSplitsForLeaves() {
}


// Apply the split recorded for best_Leaf: grow `tree`, partition the rows, and
// re-initialize smaller_leaf_splits_/larger_leaf_splits_ for the two children.
// On return, *left_leaf == best_Leaf (the left child keeps the parent's index),
// and *right_leaf holds the index returned by Tree::Split.
void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
  const SplitInfo& best_split_info = best_split_per_leaf_[best_Leaf];
  // left = parent
  *left_leaf = best_Leaf;
  // split tree, will return right leaf
  *right_leaf = tree->Split(best_Leaf, best_split_info.feature,
    best_split_info.threshold,
    train_data_->RealFeatureIndex(best_split_info.feature),
    train_data_->RealThreshold(best_split_info.feature, best_split_info.threshold),
    static_cast<double>(best_split_info.left_output),
    static_cast<double>(best_split_info.right_output),
    static_cast<data_size_t>(best_split_info.left_count),
    static_cast<data_size_t>(best_split_info.right_count),
    static_cast<double>(best_split_info.gain));
  // split data partition
  data_partition_->Split(best_Leaf, train_data_, best_split_info.feature,
                         best_split_info.threshold, *right_leaf);

  // The child with fewer rows becomes the "smaller" leaf. FindBestThresholds()
  // always builds its histograms directly.
  if (best_split_info.left_count < best_split_info.right_count) {
    smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                               best_split_info.left_sum_gradient,
                               best_split_info.left_sum_hessian);
    larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                              best_split_info.right_sum_gradient,
                              best_split_info.right_sum_hessian);
  } else {
    smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                               best_split_info.right_sum_gradient,
                               best_split_info.right_sum_hessian);
    larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                              best_split_info.left_sum_gradient,
                              best_split_info.left_sum_hessian);
  }
}

}  // namespace LightGBM