// serial_tree_learner.cpp
#include "serial_tree_learner.h"

#include <LightGBM/utils/array_args.h>

#include <algorithm>
#include <vector>

namespace LightGBM {

// Creates the serial tree learner.
// tree_config: tree-growing parameters. Only the pointer is stored, so the
//   config object must stay alive while this learner is in use
//   (NOTE(review): ownership is assumed to remain with the caller).
SerialTreeLearner::SerialTreeLearner(const TreeConfig* tree_config)
  :tree_config_(tree_config) {
  // RNG behind the per-tree feature sub-sampling done in BeforeTrain()
  random_ = Random(tree_config_->feature_fraction_seed);
  // Record the size of the default OpenMP team; FindBestThresholds() keeps
  // one best-split candidate per thread, indexed by omp_get_thread_num().
#pragma omp parallel
#pragma omp master
  {
    num_threads_ = omp_get_num_threads();
  }
}

// No explicit cleanup: owned objects are held by smart-pointer and container
// members, which release them in their own destructors.
SerialTreeLearner::~SerialTreeLearner() = default;

void SerialTreeLearner::Init(const Dataset* train_data) {
  train_data_ = train_data;
  num_data_ = train_data_->num_data();
  num_features_ = train_data_->num_features();
27
28
  int max_cache_size = 0;
  // Get the max size of pool
Guolin Ke's avatar
Guolin Ke committed
29
30
  if (tree_config_->histogram_pool_size <= 0) {
    max_cache_size = tree_config_->num_leaves;
31
32
33
  } else {
    size_t total_histogram_size = 0;
    for (int i = 0; i < train_data_->num_features(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
34
      total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureNumBin(i);
35
    }
Guolin Ke's avatar
Guolin Ke committed
36
    max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
37
38
  }
  // at least need 2 leaves
Guolin Ke's avatar
Guolin Ke committed
39
  max_cache_size = std::max(2, max_cache_size);
Guolin Ke's avatar
Guolin Ke committed
40
  max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
41

Guolin Ke's avatar
Guolin Ke committed
42
  histogram_pool_.DynamicChangeSize(train_data_, tree_config_, max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
43
  // push split information for all leaves
Guolin Ke's avatar
Guolin Ke committed
44
  best_split_per_leaf_.resize(tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
45
  
Guolin Ke's avatar
Guolin Ke committed
46
  // get ordered bin
Guolin Ke's avatar
Guolin Ke committed
47
  train_data_->CreateOrderedBins(&ordered_bins_);
Guolin Ke's avatar
Guolin Ke committed
48
49

  // check existing for ordered bin
Guolin Ke's avatar
Guolin Ke committed
50
  for (int i = 0; i < static_cast<int>(ordered_bins_.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
51
52
53
54
55
    if (ordered_bins_[i] != nullptr) {
      has_ordered_bin_ = true;
      break;
    }
  }
wxchan's avatar
wxchan committed
56
  // initialize splits for leaf
Guolin Ke's avatar
Guolin Ke committed
57
58
  smaller_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
  larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
Guolin Ke's avatar
Guolin Ke committed
59
60

  // initialize data partition
Guolin Ke's avatar
Guolin Ke committed
61
  data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
Guolin Ke's avatar
Guolin Ke committed
62
  is_feature_used_.resize(num_features_);
Guolin Ke's avatar
Guolin Ke committed
63
  // initialize ordered gradients and hessians
Guolin Ke's avatar
Guolin Ke committed
64
65
66
  ordered_gradients_.resize(num_data_);
  ordered_hessians_.resize(num_data_);
  // if has ordered bin, need to allocate a buffer to fast split
Guolin Ke's avatar
Guolin Ke committed
67
  if (has_ordered_bin_) {
Guolin Ke's avatar
Guolin Ke committed
68
    is_data_in_leaf_.resize(num_data_);
Guolin Ke's avatar
Guolin Ke committed
69
    std::fill(is_data_in_leaf_.begin(), is_data_in_leaf_.end(), 0);
Guolin Ke's avatar
Guolin Ke committed
70
  }
Guolin Ke's avatar
Guolin Ke committed
71
  Log::Info("Number of data: %d, number of used features: %d", num_data_, num_features_);
Guolin Ke's avatar
Guolin Ke committed
72
73
}

// Switches the learner to a new training set. It reuses the objects created by
// Init(), which must therefore have been called first, and resizes them to the
// new number of data points.
// NOTE(review): the histogram pool is not rebuilt here, so this presumably
// assumes the new dataset has the same features/bins as the previous one.
void SerialTreeLearner::ResetTrainingData(const Dataset* train_data) {
  train_data_ = train_data;
  num_data_ = train_data_->num_data();
  num_features_ = train_data_->num_features();

  // get ordered bin
  train_data_->CreateOrderedBins(&ordered_bins_);

  has_ordered_bin_ = false;
  // check existing for ordered bin
  for (int i = 0; i < static_cast<int>(ordered_bins_.size()); ++i) {
    if (ordered_bins_[i] != nullptr) {
      has_ordered_bin_ = true;
      break;
    }
  }
  // initialize splits for leaf
  smaller_leaf_splits_->ResetNumData(num_data_);
  larger_leaf_splits_->ResetNumData(num_data_);

  // initialize data partition
  data_partition_->ResetNumData(num_data_);

  is_feature_used_.resize(num_features_);

  // initialize ordered gradients and hessians
  ordered_gradients_.resize(num_data_);
  ordered_hessians_.resize(num_data_);
  // if has ordered bin, need to allocate a buffer to fast split
  if (has_ordered_bin_) {
    is_data_in_leaf_.resize(num_data_);
    std::fill(is_data_in_leaf_.begin(), is_data_in_leaf_.end(), 0);
  }
}
void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
  if (tree_config_->num_leaves != tree_config->num_leaves) {
    tree_config_ = tree_config;
    int max_cache_size = 0;
    // Get the max size of pool
    if (tree_config->histogram_pool_size <= 0) {
      max_cache_size = tree_config_->num_leaves;
    } else {
      size_t total_histogram_size = 0;
      for (int i = 0; i < train_data_->num_features(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
120
        total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureNumBin(i);
Guolin Ke's avatar
Guolin Ke committed
121
122
123
124
125
126
      }
      max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
    }
    // at least need 2 leaves
    max_cache_size = std::max(2, max_cache_size);
    max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
127
    histogram_pool_.DynamicChangeSize(train_data_, tree_config_, max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
128
129
130

    // push split information for all leaves
    best_split_per_leaf_.resize(tree_config_->num_leaves);
131
    data_partition_->ResetLeaves(tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
  } else {
    tree_config_ = tree_config;
  }

Guolin Ke's avatar
Guolin Ke committed
136
  histogram_pool_.ResetConfig(tree_config_);
Guolin Ke's avatar
Guolin Ke committed
137
138
}

// Grows one tree from the given per-sample gradients and hessians.
// At each step the leaf whose stored best split has the largest gain is split.
// Growth stops after num_leaves - 1 splits, or earlier when no split has
// positive gain.
// Returns a newly allocated Tree released from a unique_ptr; the caller takes
// ownership. The learner also keeps a non-owning pointer in last_trained_tree_.
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
  gradients_ = gradients;
  hessians_ = hessians;
  // some initial works before training
  BeforeTrain();
  auto tree = std::unique_ptr<Tree>(new Tree(tree_config_->num_leaves));
  // save pointer to last trained tree
  last_trained_tree_ = tree.get();
  // root leaf
  int left_leaf = 0;
  // only root leaf can be splitted on first time
  int right_leaf = -1;
  for (int split = 0; split < tree_config_->num_leaves - 1; ++split) {
    // some initial works before finding best split
    if (BeforeFindBestSplit(left_leaf, right_leaf)) {
      // find best threshold for every feature
      FindBestThresholds();
      // find best split from all features
      FindBestSplitsForLeaves();
    }
    // Get a leaf with max split gain
    int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
    // Get split information for best leaf
    const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf];
    // cannot split, quit
    if (best_leaf_SplitInfo.gain <= 0.0) {
      Log::Info("No further splits with positive gain, best gain: %f, leaves: %d",
                best_leaf_SplitInfo.gain, split + 1);
      break;
    }
    // split tree with best leaf; left_leaf/right_leaf become the two children
    Split(tree.get(), best_leaf, &left_leaf, &right_leaf);
  }
  return tree.release();
}

void SerialTreeLearner::BeforeTrain() {
Guolin Ke's avatar
Guolin Ke committed
176

177
178
  // reset histogram pool
  histogram_pool_.ResetMap();
Guolin Ke's avatar
Guolin Ke committed
179
  int used_feature_cnt = static_cast<int>(num_features_*tree_config_->feature_fraction);
Guolin Ke's avatar
Guolin Ke committed
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194

  if (used_feature_cnt < num_features_) {
    // initialize used features
    std::memset(is_feature_used_.data(), 0, sizeof(int8_t) * num_features_);
    // Get used feature at current tree
    auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt);
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < static_cast<int>(used_feature_indices.size()); ++i) {
      is_feature_used_[used_feature_indices[i]] = 1;
    }
  } else {
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < num_features_; ++i) {
      is_feature_used_[i] = 1;
    }
Guolin Ke's avatar
Guolin Ke committed
195
  }
196

Guolin Ke's avatar
Guolin Ke committed
197
198
199
200
  // initialize data partition
  data_partition_->Init();

  // reset the splits for leaves
Guolin Ke's avatar
Guolin Ke committed
201
  for (int i = 0; i < tree_config_->num_leaves; ++i) {
Guolin Ke's avatar
Guolin Ke committed
202
203
204
205
206
207
208
    best_split_per_leaf_[i].Reset();
  }

  // Sumup for root
  if (data_partition_->leaf_count(0) == num_data_) {
    // use all data
    smaller_leaf_splits_->Init(gradients_, hessians_);
Guolin Ke's avatar
Guolin Ke committed
209

Guolin Ke's avatar
Guolin Ke committed
210
211
  } else {
    // use bagging, only use part of data
Guolin Ke's avatar
Guolin Ke committed
212
    smaller_leaf_splits_->Init(0, data_partition_.get(), gradients_, hessians_);
Guolin Ke's avatar
Guolin Ke committed
213
214
215
216
217
218
219
220
221
  }

  larger_leaf_splits_->Init();

  // if has ordered bin, need to initialize the ordered bin
  if (has_ordered_bin_) {
    if (data_partition_->leaf_count(0) == num_data_) {
      // use all data, pass nullptr
      #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
222
223
224
225
      for (int i = 0; i < static_cast<int>(ordered_bins_.size()); ++i) {
        auto ptr = ordered_bins_[i].get();
        if (ptr != nullptr) {
          ptr->Init(nullptr, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
        }
      }
    } else {
      // bagging, only use part of data

      // mark used data
      const data_size_t* indices = data_partition_->indices();
      data_size_t begin = data_partition_->leaf_begin(0);
      data_size_t end = begin + data_partition_->leaf_count(0);
      #pragma omp parallel for schedule(static)
      for (data_size_t i = begin; i < end; ++i) {
        is_data_in_leaf_[indices[i]] = 1;
      }
      // initialize ordered bin
      #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
241
242
243
244
      for (int i = 0; i < static_cast<int>(ordered_bins_.size()); ++i) {
        auto ptr = ordered_bins_[i].get();
        if (ptr != nullptr) {
          ptr->Init(is_data_in_leaf_.data(), tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
245
246
        }
      }
Guolin Ke's avatar
Guolin Ke committed
247
248
249
250
#pragma omp parallel for schedule(static)
      for (data_size_t i = begin; i < end; ++i) {
        is_data_in_leaf_[indices[i]] = 0;
      }
Guolin Ke's avatar
Guolin Ke committed
251
252
253
254
255
    }
  }
}

bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
Guolin Ke's avatar
Guolin Ke committed
256
  // check depth of current leaf
Guolin Ke's avatar
Guolin Ke committed
257
  if (tree_config_->max_depth > 0) {
Guolin Ke's avatar
Guolin Ke committed
258
    // only need to check left leaf, since right leaf is in same level of left leaf
Guolin Ke's avatar
Guolin Ke committed
259
    if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_->max_depth) {
Guolin Ke's avatar
Guolin Ke committed
260
261
262
263
264
265
266
      best_split_per_leaf_[left_leaf].gain = kMinScore;
      if (right_leaf >= 0) {
        best_split_per_leaf_[right_leaf].gain = kMinScore;
      }
      return false;
    }
  }
Guolin Ke's avatar
Guolin Ke committed
267
268
269
  data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
  data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
  // no enough data to continue
Guolin Ke's avatar
Guolin Ke committed
270
271
  if (num_data_in_right_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)
    && num_data_in_left_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)) {
Guolin Ke's avatar
Guolin Ke committed
272
273
274
275
276
277
    best_split_per_leaf_[left_leaf].gain = kMinScore;
    if (right_leaf >= 0) {
      best_split_per_leaf_[right_leaf].gain = kMinScore;
    }
    return false;
  }
278
  parent_leaf_histogram_array_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
279
280
  // only have root
  if (right_leaf < 0) {
281
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
282
283
    larger_leaf_histogram_array_ = nullptr;
  } else if (num_data_in_left_child < num_data_in_right_child) {
Hui Xue's avatar
Hui Xue committed
284
    // put parent(left) leaf's histograms into larger leaf's histograms
285
286
287
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Move(left_leaf, right_leaf);
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
288
  } else {
Hui Xue's avatar
Hui Xue committed
289
    // put parent(left) leaf's histograms to larger leaf's histograms
290
291
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Get(right_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
292
293
294
295
296
  }
  // split for the ordered bin
  if (has_ordered_bin_ && right_leaf >= 0) {
    // mark data that at left-leaf
    const data_size_t* indices = data_partition_->indices();
Guolin Ke's avatar
Guolin Ke committed
297
298
299
    const auto left_cnt = data_partition_->leaf_count(left_leaf);
    const auto right_cnt = data_partition_->leaf_count(right_leaf);
    char mark = 1;
Guolin Ke's avatar
Guolin Ke committed
300
    data_size_t begin = data_partition_->leaf_begin(left_leaf);
Guolin Ke's avatar
Guolin Ke committed
301
302
303
304
305
306
    data_size_t end = begin + left_cnt;
    if (left_cnt > right_cnt) {
      begin = data_partition_->leaf_begin(right_leaf);
      end = begin + right_cnt;
      mark = 0;
    }
Guolin Ke's avatar
Guolin Ke committed
307
308
309
310
311
312
    #pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      is_data_in_leaf_[indices[i]] = 1;
    }
    // split the ordered bin
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
313
314
315
316
    for (int i = 0; i < static_cast<int>(ordered_bins_.size()); ++i) {
      auto ptr = ordered_bins_[i].get();
      if (ptr != nullptr) {
        ptr->Split(left_leaf, right_leaf, is_data_in_leaf_.data(), mark);
Guolin Ke's avatar
Guolin Ke committed
317
318
      }
    }
Guolin Ke's avatar
Guolin Ke committed
319
320
321
322
#pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      is_data_in_leaf_[indices[i]] = 0;
    }
Guolin Ke's avatar
Guolin Ke committed
323
324
325
326
327
  }
  return true;
}

void SerialTreeLearner::FindBestThresholds() {
Guolin Ke's avatar
Guolin Ke committed
328
329
330
331
332
333
  std::vector<int8_t> is_feature_used(num_features_, 0);
#pragma omp parallel for schedule(guided)
  for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
    if (!is_feature_used_[feature_index]) continue;
    if (parent_leaf_histogram_array_ != nullptr 
        && !parent_leaf_histogram_array_[feature_index].is_splittable()) {
Guolin Ke's avatar
Guolin Ke committed
334
335
336
      smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
      continue;
    }
Guolin Ke's avatar
Guolin Ke committed
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
    is_feature_used[feature_index] = 1;
  }
  bool use_subtract = true;
  if (parent_leaf_histogram_array_ == nullptr) {
    use_subtract = false;
  }
  // construct smaller leaf
  HistogramBinEntry* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - 1;
  train_data_->ConstructHistograms(is_feature_used,
    smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
    smaller_leaf_splits_->LeafIndex(),
    ordered_bins_, gradients_, hessians_,
    ordered_gradients_.data(), ordered_hessians_.data(),
    ptr_smaller_leaf_hist_data);

  if (larger_leaf_histogram_array_ != nullptr && !use_subtract) {
    // construct larger leaf
    HistogramBinEntry* ptr_larger_leaf_hist_data = larger_leaf_histogram_array_[0].RawData() - 1;
    train_data_->ConstructHistograms(is_feature_used,
      larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf(),
      larger_leaf_splits_->LeafIndex(),
      ordered_bins_, gradients_, hessians_,
      ordered_gradients_.data(), ordered_hessians_.data(),
      ptr_larger_leaf_hist_data);
  }
  std::vector<SplitInfo> smaller_best(num_threads_);
  std::vector<SplitInfo> larger_best(num_threads_);
  // find splits
  #pragma omp parallel for schedule(guided)
  for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
    if (!is_feature_used[feature_index]) { continue; }
    const int tid = omp_get_thread_num();
    SplitInfo smaller_split;
    train_data_->FixHistogram(feature_index, 
      smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_hessians(),
      smaller_leaf_splits_->num_data_in_leaf(),
      smaller_leaf_histogram_array_[feature_index].RawData());
Guolin Ke's avatar
Guolin Ke committed
374

Guolin Ke's avatar
Guolin Ke committed
375
376
377
378
    smaller_leaf_histogram_array_[feature_index].FindBestThreshold(
      smaller_leaf_splits_->sum_gradients(),
      smaller_leaf_splits_->sum_hessians(),
      smaller_leaf_splits_->num_data_in_leaf(),
Guolin Ke's avatar
Guolin Ke committed
379
380
381
382
      &smaller_split);
    if (smaller_split.gain > smaller_best[tid].gain) {
      smaller_best[tid] = smaller_split;
    }
Guolin Ke's avatar
Guolin Ke committed
383
    // only has root leaf
Guolin Ke's avatar
Guolin Ke committed
384
    if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
385

Guolin Ke's avatar
Guolin Ke committed
386
    if (use_subtract) {
387
388
      larger_leaf_histogram_array_[feature_index].Subtract(smaller_leaf_histogram_array_[feature_index]);
    } else {
Guolin Ke's avatar
Guolin Ke committed
389
390
391
      train_data_->FixHistogram(feature_index, larger_leaf_splits_->sum_gradients(), larger_leaf_splits_->sum_hessians(),
        larger_leaf_splits_->num_data_in_leaf(),
        larger_leaf_histogram_array_[feature_index].RawData());
392
    }
Guolin Ke's avatar
Guolin Ke committed
393
    SplitInfo larger_split;
Guolin Ke's avatar
Guolin Ke committed
394
    // find best threshold for larger child
Guolin Ke's avatar
Guolin Ke committed
395
396
397
398
    larger_leaf_histogram_array_[feature_index].FindBestThreshold(
      larger_leaf_splits_->sum_gradients(),
      larger_leaf_splits_->sum_hessians(),
      larger_leaf_splits_->num_data_in_leaf(),
Guolin Ke's avatar
Guolin Ke committed
399
400
401
402
      &larger_split);
    if (larger_split.gain > larger_best[tid].gain) {
      larger_best[tid] = larger_split;
    }
Guolin Ke's avatar
Guolin Ke committed
403
  }
Guolin Ke's avatar
Guolin Ke committed
404
405
406
407
408
409
410
411
412
413
414
415
416
417

  auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_best);
  int leaf = smaller_leaf_splits_->LeafIndex();
  best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx];

  if (larger_leaf_splits_ != nullptr && larger_leaf_splits_->LeafIndex() >= 0) {
    leaf = larger_leaf_splits_->LeafIndex();
    auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best);
    best_split_per_leaf_[leaf] = larger_best[larger_best_idx];
  }
}

// Intentionally empty in the serial learner: FindBestThresholds() already
// writes each leaf's best split into best_split_per_leaf_.
// NOTE(review): presumably a hook for learners that must combine candidate
// splits (e.g. parallel variants) -- confirm against the class hierarchy.
void SerialTreeLearner::FindBestSplitsForLeaves() {
}


// Applies the stored best split of best_Leaf to both the tree and the data
// partition, then initializes the smaller/larger leaf helpers for the next step.
// tree:       tree being grown.
// best_Leaf:  leaf to split; it keeps its index and becomes the left child.
// left_leaf:  out, index of the left child (== best_Leaf).
// right_leaf: out, index of the new right child returned by Tree::Split.
void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
  const SplitInfo& best_split_info = best_split_per_leaf_[best_Leaf];
  // left = parent
  *left_leaf = best_Leaf;
  // split tree, will return right leaf
  *right_leaf = tree->Split(best_Leaf, best_split_info.feature,
    best_split_info.threshold,
    train_data_->RealFeatureIndex(best_split_info.feature),
    train_data_->RealThreshold(best_split_info.feature, best_split_info.threshold),
    static_cast<double>(best_split_info.left_output),
    static_cast<double>(best_split_info.right_output),
    static_cast<data_size_t>(best_split_info.left_count),
    static_cast<data_size_t>(best_split_info.right_count),
    static_cast<double>(best_split_info.gain));
  // split data partition
  data_partition_->Split(best_Leaf, train_data_, best_split_info.feature,
                         best_split_info.threshold, *right_leaf);

  // init the leaves that used on next iteration: the child with fewer rows
  // becomes the "smaller" leaf
  if (best_split_info.left_count < best_split_info.right_count) {
    smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                               best_split_info.left_sum_gradient,
                               best_split_info.left_sum_hessian);
    larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                              best_split_info.right_sum_gradient,
                              best_split_info.right_sum_hessian);
  } else {
    smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian);
    larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian);
  }
}

}  // namespace LightGBM