serial_tree_learner.cpp 17.5 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
#include "serial_tree_learner.h"

#include <LightGBM/utils/array_args.h>

#include <algorithm>
#include <vector>

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
10
SerialTreeLearner::SerialTreeLearner(const TreeConfig* tree_config)
  : tree_config_(tree_config) {
  // Feature subsampling draws from random_; seed it from the configured value.
  const auto seed = tree_config->feature_fraction_seed;
  random_ = Random(seed);
}

SerialTreeLearner::~SerialTreeLearner() {
  // Intentionally empty: the heap state allocated in this file is held by
  // std::unique_ptr members and std::vector buffers, which release themselves.
}

void SerialTreeLearner::Init(const Dataset* train_data) {
  train_data_ = train_data;
  num_data_ = train_data_->num_data();
  num_features_ = train_data_->num_features();
23
24
  int max_cache_size = 0;
  // Get the max size of pool
Guolin Ke's avatar
Guolin Ke committed
25
26
  if (tree_config_->histogram_pool_size <= 0) {
    max_cache_size = tree_config_->num_leaves;
27
28
29
30
31
  } else {
    size_t total_histogram_size = 0;
    for (int i = 0; i < train_data_->num_features(); ++i) {
      total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
    }
Guolin Ke's avatar
Guolin Ke committed
32
    max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
33
34
  }
  // at least need 2 leaves
Guolin Ke's avatar
Guolin Ke committed
35
  max_cache_size = std::max(2, max_cache_size);
Guolin Ke's avatar
Guolin Ke committed
36
37
  max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
  histogram_pool_.Reset(max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
38

Guolin Ke's avatar
Guolin Ke committed
39
  auto histogram_create_function = [this]() {
Guolin Ke's avatar
Guolin Ke committed
40
    auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
Guolin Ke's avatar
Guolin Ke committed
41
    for (int j = 0; j < train_data_->num_features(); ++j) {
42
      tmp_histogram_array[j].Init(train_data_->FeatureAt(j),
Guolin Ke's avatar
Guolin Ke committed
43
        j, tree_config_);
Guolin Ke's avatar
Guolin Ke committed
44
    }
Guolin Ke's avatar
Guolin Ke committed
45
    return tmp_histogram_array.release();
Guolin Ke's avatar
Guolin Ke committed
46
47
48
  };
  histogram_pool_.Fill(histogram_create_function);

Guolin Ke's avatar
Guolin Ke committed
49
  // push split information for all leaves
Guolin Ke's avatar
Guolin Ke committed
50
  best_split_per_leaf_.resize(tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
51
  // initialize ordered_bins_ with nullptr
Guolin Ke's avatar
Guolin Ke committed
52
  ordered_bins_.resize(num_features_);
Guolin Ke's avatar
Guolin Ke committed
53
54
55
56

  // get ordered bin
  #pragma omp parallel for schedule(guided)
  for (int i = 0; i < num_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
57
    ordered_bins_[i].reset(train_data_->FeatureAt(i)->bin_data()->CreateOrderedBin());
Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
62
63
64
65
66
  }

  // check existing for ordered bin
  for (int i = 0; i < num_features_; ++i) {
    if (ordered_bins_[i] != nullptr) {
      has_ordered_bin_ = true;
      break;
    }
  }
wxchan's avatar
wxchan committed
67
  // initialize splits for leaf
Guolin Ke's avatar
Guolin Ke committed
68
69
  smaller_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
  larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
Guolin Ke's avatar
Guolin Ke committed
70
71

  // initialize data partition
Guolin Ke's avatar
Guolin Ke committed
72
  data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
Guolin Ke's avatar
Guolin Ke committed
73

Guolin Ke's avatar
Guolin Ke committed
74
  is_feature_used_.resize(num_features_);
Guolin Ke's avatar
Guolin Ke committed
75
76

  // initialize ordered gradients and hessians
Guolin Ke's avatar
Guolin Ke committed
77
78
79
  ordered_gradients_.resize(num_data_);
  ordered_hessians_.resize(num_data_);
  // if has ordered bin, need to allocate a buffer to fast split
Guolin Ke's avatar
Guolin Ke committed
80
  if (has_ordered_bin_) {
Guolin Ke's avatar
Guolin Ke committed
81
    is_data_in_leaf_.resize(num_data_);
Guolin Ke's avatar
Guolin Ke committed
82
  }
83
  Log::Info("Number of data: %d, number of features: %d", num_data_, num_features_);
Guolin Ke's avatar
Guolin Ke committed
84
85
86
}


Guolin Ke's avatar
Guolin Ke committed
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
// Switches the learner to a new tree configuration.
// When num_leaves changes, the histogram cache size, the per-leaf split slots and
// the data partition are re-sized for the new leaf count. The histogram pool is
// always told about the new config.
//
// @param tree_config new configuration; stored as a non-owning pointer.
void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
  const bool num_leaves_changed = tree_config_->num_leaves != tree_config->num_leaves;
  tree_config_ = tree_config;
  if (num_leaves_changed) {
    // Get the max size of pool; non-positive histogram_pool_size means one per leaf.
    int max_cache_size = tree_config_->num_leaves;
    if (tree_config_->histogram_pool_size > 0) {
      size_t total_histogram_size = 0;
      for (int i = 0; i < train_data_->num_features(); ++i) {
        total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
      }
      // Avoid division by zero (no features) and clamp before the int conversion,
      // since converting an out-of-range double to int is undefined behavior.
      if (total_histogram_size > 0) {
        const double num_fit = tree_config_->histogram_pool_size * 1024.0 * 1024.0
                               / static_cast<double>(total_histogram_size);
        max_cache_size = static_cast<int>(
          std::min(num_fit, static_cast<double>(tree_config_->num_leaves)));
      }
    }
    // at least need 2 leaves
    max_cache_size = std::max(2, max_cache_size);
    max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
    histogram_pool_.DynamicChangeSize(max_cache_size, tree_config_->num_leaves);

    // push split information for all leaves
    best_split_per_leaf_.resize(tree_config_->num_leaves);
    data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
  }

  histogram_pool_.ResetConfig(tree_config_, train_data_->num_features());
}

Guolin Ke's avatar
Guolin Ke committed
116
117
118
119
120
// Grows one tree leaf-wise from the given per-row gradients and hessians.
// Each round evaluates split candidates for the two leaves produced by the
// previous split, then splits whichever leaf has the largest gain. Growth stops
// after num_leaves - 1 splits or when no leaf has positive gain.
//
// @return the new tree; the caller takes ownership of the raw pointer.
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
  gradients_ = gradients;
  hessians_ = hessians;
  // Per-tree setup: histogram pool, sampled features, partition, root sums.
  BeforeTrain();

  std::unique_ptr<Tree> tree(new Tree(tree_config_->num_leaves));
  // Kept so depth checks in BeforeFindBestSplit can query the tree being grown.
  last_trained_tree_ = tree.get();

  // Leaves created by the most recent split. Initially only the root (0) exists,
  // so there is no right leaf yet.
  int left_leaf = 0;
  int right_leaf = -1;
  const int max_splits = tree_config_->num_leaves - 1;
  for (int split = 0; split < max_splits; ++split) {
    const bool can_search = BeforeFindBestSplit(left_leaf, right_leaf);
    if (can_search) {
      FindBestThresholds();
      FindBestSplitsForLeaves();
    }

    const int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
    const SplitInfo& best_split = best_split_per_leaf_[best_leaf];
    if (best_split.gain <= 0.0) {
      Log::Info("No further splits with positive gain, best gain: %f, leaves: %d",
                best_split.gain, split + 1);
      break;
    }
    Split(tree.get(), best_leaf, &left_leaf, &right_leaf);
  }
  return tree.release();
}

void SerialTreeLearner::BeforeTrain() {
Guolin Ke's avatar
Guolin Ke committed
153

154
155
  // reset histogram pool
  histogram_pool_.ResetMap();
Guolin Ke's avatar
Guolin Ke committed
156
157
158
159
160
  // initialize used features
  for (int i = 0; i < num_features_; ++i) {
    is_feature_used_[i] = false;
  }
  // Get used feature at current tree
Guolin Ke's avatar
Guolin Ke committed
161
  int used_feature_cnt = static_cast<int>(num_features_*tree_config_->feature_fraction);
162
  auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt);
Guolin Ke's avatar
Guolin Ke committed
163
164
165
  for (auto idx : used_feature_indices) {
    is_feature_used_[idx] = true;
  }
166

Guolin Ke's avatar
Guolin Ke committed
167
168
169
170
  // initialize data partition
  data_partition_->Init();

  // reset the splits for leaves
Guolin Ke's avatar
Guolin Ke committed
171
  for (int i = 0; i < tree_config_->num_leaves; ++i) {
Guolin Ke's avatar
Guolin Ke committed
172
173
174
175
176
177
178
179
    best_split_per_leaf_[i].Reset();
  }

  // Sumup for root
  if (data_partition_->leaf_count(0) == num_data_) {
    // use all data
    smaller_leaf_splits_->Init(gradients_, hessians_);
    // point to gradients, avoid copy
180
181
    ptr_to_ordered_gradients_smaller_leaf_ = gradients_;
    ptr_to_ordered_hessians_smaller_leaf_  = hessians_;
Guolin Ke's avatar
Guolin Ke committed
182
183
  } else {
    // use bagging, only use part of data
Guolin Ke's avatar
Guolin Ke committed
184
    smaller_leaf_splits_->Init(0, data_partition_.get(), gradients_, hessians_);
Guolin Ke's avatar
Guolin Ke committed
185
186
187
188
189
190
191
192
193
    // copy used gradients and hessians to ordered buffer
    const data_size_t* indices = data_partition_->indices();
    data_size_t cnt = data_partition_->leaf_count(0);
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < cnt; ++i) {
      ordered_gradients_[i] = gradients_[indices[i]];
      ordered_hessians_[i] = hessians_[indices[i]];
    }
    // point to ordered_gradients_ and ordered_hessians_
Guolin Ke's avatar
Guolin Ke committed
194
195
    ptr_to_ordered_gradients_smaller_leaf_ = ordered_gradients_.data();
    ptr_to_ordered_hessians_smaller_leaf_ = ordered_hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
196
197
  }

198
199
200
  ptr_to_ordered_gradients_larger_leaf_ = nullptr;
  ptr_to_ordered_hessians_larger_leaf_ = nullptr;

Guolin Ke's avatar
Guolin Ke committed
201
202
203
204
205
206
207
208
209
  larger_leaf_splits_->Init();

  // if has ordered bin, need to initialize the ordered bin
  if (has_ordered_bin_) {
    if (data_partition_->leaf_count(0) == num_data_) {
      // use all data, pass nullptr
      #pragma omp parallel for schedule(guided)
      for (int i = 0; i < num_features_; ++i) {
        if (ordered_bins_[i] != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
210
          ordered_bins_[i]->Init(nullptr, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
211
212
213
214
215
216
        }
      }
    } else {
      // bagging, only use part of data

      // mark used data
Guolin Ke's avatar
Guolin Ke committed
217
      std::memset(is_data_in_leaf_.data(), 0, sizeof(char)*num_data_);
Guolin Ke's avatar
Guolin Ke committed
218
219
220
221
222
223
224
225
226
227
228
      const data_size_t* indices = data_partition_->indices();
      data_size_t begin = data_partition_->leaf_begin(0);
      data_size_t end = begin + data_partition_->leaf_count(0);
      #pragma omp parallel for schedule(static)
      for (data_size_t i = begin; i < end; ++i) {
        is_data_in_leaf_[indices[i]] = 1;
      }
      // initialize ordered bin
      #pragma omp parallel for schedule(guided)
      for (int i = 0; i < num_features_; ++i) {
        if (ordered_bins_[i] != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
229
          ordered_bins_[i]->Init(is_data_in_leaf_.data(), tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
230
231
232
233
234
235
236
        }
      }
    }
  }
}

bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
Guolin Ke's avatar
Guolin Ke committed
237
  // check depth of current leaf
Guolin Ke's avatar
Guolin Ke committed
238
  if (tree_config_->max_depth > 0) {
Guolin Ke's avatar
Guolin Ke committed
239
    // only need to check left leaf, since right leaf is in same level of left leaf
Guolin Ke's avatar
Guolin Ke committed
240
    if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_->max_depth) {
Guolin Ke's avatar
Guolin Ke committed
241
242
243
244
245
246
247
      best_split_per_leaf_[left_leaf].gain = kMinScore;
      if (right_leaf >= 0) {
        best_split_per_leaf_[right_leaf].gain = kMinScore;
      }
      return false;
    }
  }
Guolin Ke's avatar
Guolin Ke committed
248
249
250
  data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
  data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
  // no enough data to continue
Guolin Ke's avatar
Guolin Ke committed
251
252
  if (num_data_in_right_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)
    && num_data_in_left_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)) {
Guolin Ke's avatar
Guolin Ke committed
253
254
255
256
257
258
    best_split_per_leaf_[left_leaf].gain = kMinScore;
    if (right_leaf >= 0) {
      best_split_per_leaf_[right_leaf].gain = kMinScore;
    }
    return false;
  }
259
  parent_leaf_histogram_array_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
260
261
  // -1 if only has one leaf. else equal the index of smaller leaf
  int smaller_leaf = -1;
262
  int larger_leaf = -1;
Guolin Ke's avatar
Guolin Ke committed
263
264
  // only have root
  if (right_leaf < 0) {
265
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
266
    larger_leaf_histogram_array_ = nullptr;
267

Guolin Ke's avatar
Guolin Ke committed
268
269
  } else if (num_data_in_left_child < num_data_in_right_child) {
    smaller_leaf = left_leaf;
270
    larger_leaf = right_leaf;
Hui Xue's avatar
Hui Xue committed
271
    // put parent(left) leaf's histograms into larger leaf's histograms
272
273
274
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Move(left_leaf, right_leaf);
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
275
276
  } else {
    smaller_leaf = right_leaf;
277
    larger_leaf = left_leaf;
Hui Xue's avatar
Hui Xue committed
278
    // put parent(left) leaf's histograms to larger leaf's histograms
279
280
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Get(right_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
  }

  // init for the ordered gradients, only initialize when have 2 leaves
  if (smaller_leaf >= 0) {
    // only need to initialize for smaller leaf

    // Get leaf boundary
    const data_size_t* indices = data_partition_->indices();
    data_size_t begin = data_partition_->leaf_begin(smaller_leaf);
    data_size_t end = begin + data_partition_->leaf_count(smaller_leaf);
    // copy
    #pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      ordered_gradients_[i - begin] = gradients_[indices[i]];
      ordered_hessians_[i - begin] = hessians_[indices[i]];
    }
    // assign pointer
Guolin Ke's avatar
Guolin Ke committed
298
299
    ptr_to_ordered_gradients_smaller_leaf_ = ordered_gradients_.data();
    ptr_to_ordered_hessians_smaller_leaf_ = ordered_hessians_.data();
300
301
302
303
304
305
306
307
308
309
310
311

    if (parent_leaf_histogram_array_ == nullptr) {
      // need order gradient for larger leaf
      data_size_t smaller_size = end - begin;
      data_size_t larger_begin = data_partition_->leaf_begin(larger_leaf);
      data_size_t larger_end = larger_begin + data_partition_->leaf_count(larger_leaf);
      // copy
      #pragma omp parallel for schedule(static)
      for (data_size_t i = larger_begin; i < larger_end; ++i) {
        ordered_gradients_[smaller_size + i - larger_begin] = gradients_[indices[i]];
        ordered_hessians_[smaller_size + i - larger_begin] = hessians_[indices[i]];
      }
Guolin Ke's avatar
Guolin Ke committed
312
313
      ptr_to_ordered_gradients_larger_leaf_ = ordered_gradients_.data() + smaller_size;
      ptr_to_ordered_hessians_larger_leaf_ = ordered_hessians_.data() + smaller_size;
314
    }
Guolin Ke's avatar
Guolin Ke committed
315
316
317
318
319
  }

  // split for the ordered bin
  if (has_ordered_bin_ && right_leaf >= 0) {
    // mark data that at left-leaf
Guolin Ke's avatar
Guolin Ke committed
320
    std::memset(is_data_in_leaf_.data(), 0, sizeof(char)*num_data_);
Guolin Ke's avatar
Guolin Ke committed
321
322
323
324
325
326
327
328
329
330
331
    const data_size_t* indices = data_partition_->indices();
    data_size_t begin = data_partition_->leaf_begin(left_leaf);
    data_size_t end = begin + data_partition_->leaf_count(left_leaf);
    #pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      is_data_in_leaf_[indices[i]] = 1;
    }
    // split the ordered bin
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_features_; ++i) {
      if (ordered_bins_[i] != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
332
        ordered_bins_[i]->Split(left_leaf, right_leaf, is_data_in_leaf_.data());
Guolin Ke's avatar
Guolin Ke committed
333
334
335
336
337
338
339
340
341
342
343
      }
    }
  }
  return true;
}


void SerialTreeLearner::FindBestThresholds() {
  #pragma omp parallel for schedule(guided)
  for (int feature_index = 0; feature_index < num_features_; feature_index++) {
    // feature is not used
Guolin Ke's avatar
Guolin Ke committed
344
    if ((!is_feature_used_.empty() && is_feature_used_[feature_index] == false)) continue;
Guolin Ke's avatar
Guolin Ke committed
345
    // if parent(larger) leaf cannot split at current feature
346
    if (parent_leaf_histogram_array_ != nullptr && !parent_leaf_histogram_array_[feature_index].is_splittable()) {
Guolin Ke's avatar
Guolin Ke committed
347
348
349
350
351
352
353
354
355
356
357
      smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
      continue;
    }

    // construct histograms for smaller leaf
    if (ordered_bins_[feature_index] == nullptr) {
      // if not use ordered bin
      smaller_leaf_histogram_array_[feature_index].Construct(smaller_leaf_splits_->data_indices(),
        smaller_leaf_splits_->num_data_in_leaf(),
        smaller_leaf_splits_->sum_gradients(),
        smaller_leaf_splits_->sum_hessians(),
358
359
        ptr_to_ordered_gradients_smaller_leaf_,
        ptr_to_ordered_hessians_smaller_leaf_);
Guolin Ke's avatar
Guolin Ke committed
360
361
    } else {
      // used ordered bin
Guolin Ke's avatar
Guolin Ke committed
362
      smaller_leaf_histogram_array_[feature_index].Construct(ordered_bins_[feature_index].get(),
Guolin Ke's avatar
Guolin Ke committed
363
364
365
366
367
368
369
370
371
372
373
374
375
        smaller_leaf_splits_->LeafIndex(),
        smaller_leaf_splits_->num_data_in_leaf(),
        smaller_leaf_splits_->sum_gradients(),
        smaller_leaf_splits_->sum_hessians(),
        gradients_,
        hessians_);
    }
    // find best threshold for smaller child
    smaller_leaf_histogram_array_[feature_index].FindBestThreshold(&smaller_leaf_splits_->BestSplitPerFeature()[feature_index]);

    // only has root leaf
    if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) continue;

376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
    if (parent_leaf_histogram_array_ != nullptr) {
      // construct histgroms for large leaf, we initialize larger leaf as the parent,
      // so we can just subtract the smaller leaf's histograms
      larger_leaf_histogram_array_[feature_index].Subtract(smaller_leaf_histogram_array_[feature_index]);
    } else {
      if (ordered_bins_[feature_index] == nullptr) {
        // if not use ordered bin
        larger_leaf_histogram_array_[feature_index].Construct(larger_leaf_splits_->data_indices(),
          larger_leaf_splits_->num_data_in_leaf(),
          larger_leaf_splits_->sum_gradients(),
          larger_leaf_splits_->sum_hessians(),
          ptr_to_ordered_gradients_larger_leaf_,
          ptr_to_ordered_hessians_larger_leaf_);
      } else {
        // used ordered bin
Guolin Ke's avatar
Guolin Ke committed
391
        larger_leaf_histogram_array_[feature_index].Construct(ordered_bins_[feature_index].get(),
392
393
394
395
396
397
398
399
          larger_leaf_splits_->LeafIndex(),
          larger_leaf_splits_->num_data_in_leaf(),
          larger_leaf_splits_->sum_gradients(),
          larger_leaf_splits_->sum_hessians(),
          gradients_,
          hessians_);
      }
    }
Guolin Ke's avatar
Guolin Ke committed
400
401
402
403
404
405
406
407
408
409
410
411
412

    // find best threshold for larger child
    larger_leaf_histogram_array_[feature_index].FindBestThreshold(&larger_leaf_splits_->BestSplitPerFeature()[feature_index]);
  }
}


// Applies the best split stored for best_Leaf. It splits the tree node,
// partitions the leaf's rows, and initializes the smaller/larger LeafSplits for
// the next round.
//
// @param tree       tree being grown
// @param best_Leaf  leaf to split; it keeps its index and becomes the left child
// @param left_leaf  out: index of the left child (== best_Leaf)
// @param right_leaf out: index of the new right child returned by Tree::Split
void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
  const SplitInfo& info = best_split_per_leaf_[best_Leaf];
  const auto& feature = train_data_->FeatureAt(info.feature);

  // left = parent
  *left_leaf = best_Leaf;
  *right_leaf = tree->Split(best_Leaf, info.feature,
    feature->bin_type(),
    info.threshold,
    feature->feature_index(),
    feature->BinToValue(info.threshold),
    static_cast<double>(info.left_output),
    static_cast<double>(info.right_output),
    static_cast<data_size_t>(info.left_count),
    static_cast<data_size_t>(info.right_count),
    static_cast<double>(info.gain));

  // Move the leaf's rows into the two children.
  data_partition_->Split(best_Leaf, feature->bin_data(), info.threshold, *right_leaf);

  // The child with fewer rows becomes the "smaller" leaf for the next round
  // (ties go to the right child).
  const bool left_is_smaller = info.left_count < info.right_count;
  if (left_is_smaller) {
    smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                               info.left_sum_gradient, info.left_sum_hessian);
    larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                              info.right_sum_gradient, info.right_sum_hessian);
  } else {
    smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                               info.right_sum_gradient, info.right_sum_hessian);
    larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                              info.left_sum_gradient, info.left_sum_hessian);
  }
}

}  // namespace LightGBM