"src/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "011fe02487a13d33058fe922b08921c8a71081eb"
serial_tree_learner.cpp 17.5 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
#include "serial_tree_learner.h"

#include <LightGBM/utils/array_args.h>

#include <algorithm>
#include <cstring>
#include <vector>

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
10
// Builds a learner bound to `tree_config` (not owned; the pointer is stored as-is).
// The feature sampler is seeded from feature_fraction_seed; BeforeTrain draws
// each tree's feature subset from it.
SerialTreeLearner::SerialTreeLearner(const TreeConfig* tree_config)
  : tree_config_(tree_config) {
  const auto seed = tree_config_->feature_fraction_seed;
  random_ = Random(seed);
}

// No hand-written cleanup is needed. ordered_bins_, smaller/larger_leaf_splits_
// and data_partition_ are assigned through reset(), so they are released by their
// own destructors. NOTE(review): histogram_pool_ is assumed to free the arrays
// handed to it by Fill() — confirm in HistogramPool.
SerialTreeLearner::~SerialTreeLearner() = default;

void SerialTreeLearner::Init(const Dataset* train_data) {
  train_data_ = train_data;
  num_data_ = train_data_->num_data();
  num_features_ = train_data_->num_features();
23
24
  int max_cache_size = 0;
  // Get the max size of pool
Guolin Ke's avatar
Guolin Ke committed
25
26
  if (tree_config_->histogram_pool_size <= 0) {
    max_cache_size = tree_config_->num_leaves;
27
28
29
30
31
  } else {
    size_t total_histogram_size = 0;
    for (int i = 0; i < train_data_->num_features(); ++i) {
      total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
    }
Guolin Ke's avatar
Guolin Ke committed
32
    max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
33
34
  }
  // at least need 2 leaves
Guolin Ke's avatar
Guolin Ke committed
35
  max_cache_size = std::max(2, max_cache_size);
Guolin Ke's avatar
Guolin Ke committed
36
37
  max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
  histogram_pool_.Reset(max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
38

Guolin Ke's avatar
Guolin Ke committed
39
  auto histogram_create_function = [this]() {
Guolin Ke's avatar
Guolin Ke committed
40
    auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
Guolin Ke's avatar
Guolin Ke committed
41
    for (int j = 0; j < train_data_->num_features(); ++j) {
42
      tmp_histogram_array[j].Init(train_data_->FeatureAt(j),
Guolin Ke's avatar
Guolin Ke committed
43
        j, tree_config_);
Guolin Ke's avatar
Guolin Ke committed
44
    }
Guolin Ke's avatar
Guolin Ke committed
45
    return tmp_histogram_array.release();
Guolin Ke's avatar
Guolin Ke committed
46
47
48
  };
  histogram_pool_.Fill(histogram_create_function);

Guolin Ke's avatar
Guolin Ke committed
49
  // push split information for all leaves
Guolin Ke's avatar
Guolin Ke committed
50
  best_split_per_leaf_.resize(tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
51
  // initialize ordered_bins_ with nullptr
Guolin Ke's avatar
Guolin Ke committed
52
  ordered_bins_.resize(num_features_);
Guolin Ke's avatar
Guolin Ke committed
53
54
55
56

  // get ordered bin
  #pragma omp parallel for schedule(guided)
  for (int i = 0; i < num_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
57
    ordered_bins_[i].reset(train_data_->FeatureAt(i)->bin_data()->CreateOrderedBin());
Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
62
63
64
65
66
  }

  // check existing for ordered bin
  for (int i = 0; i < num_features_; ++i) {
    if (ordered_bins_[i] != nullptr) {
      has_ordered_bin_ = true;
      break;
    }
  }
wxchan's avatar
wxchan committed
67
  // initialize splits for leaf
Guolin Ke's avatar
Guolin Ke committed
68
69
  smaller_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
  larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
Guolin Ke's avatar
Guolin Ke committed
70
71

  // initialize data partition
Guolin Ke's avatar
Guolin Ke committed
72
  data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
Guolin Ke's avatar
Guolin Ke committed
73

Guolin Ke's avatar
Guolin Ke committed
74
  is_feature_used_.resize(num_features_);
Guolin Ke's avatar
Guolin Ke committed
75
76

  // initialize ordered gradients and hessians
Guolin Ke's avatar
Guolin Ke committed
77
78
79
  ordered_gradients_.resize(num_data_);
  ordered_hessians_.resize(num_data_);
  // if has ordered bin, need to allocate a buffer to fast split
Guolin Ke's avatar
Guolin Ke committed
80
  if (has_ordered_bin_) {
Guolin Ke's avatar
Guolin Ke committed
81
    is_data_in_leaf_.resize(num_data_);
Guolin Ke's avatar
Guolin Ke committed
82
  }
83
  Log::Info("Number of data: %d, number of features: %d", num_data_, num_features_);
Guolin Ke's avatar
Guolin Ke committed
84
85
86
}


Guolin Ke's avatar
Guolin Ke committed
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// Switches the learner to `tree_config`. Per-leaf structures (histogram pool
// capacity, best_split_per_leaf_, data partition) are rebuilt only when
// num_leaves changes. The histogram pool config and the feature sampler are
// always refreshed.
void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
  // Compare against the previous config before replacing it.
  const bool num_leaves_changed = tree_config_->num_leaves != tree_config->num_leaves;
  tree_config_ = tree_config;

  if (num_leaves_changed) {
    // Get the max size of pool: histogram_pool_size <= 0 means one cached
    // histogram array per leaf. Otherwise it is a budget of
    // histogram_pool_size * 1024 * 1024 bytes.
    int max_cache_size = tree_config_->num_leaves;
    if (tree_config_->histogram_pool_size > 0) {
      size_t total_histogram_size = 0;
      for (int i = 0; i < train_data_->num_features(); ++i) {
        total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
      }
      // Guard the division by zero, and clamp before the int cast (an
      // out-of-range double -> int conversion is undefined behavior).
      if (total_histogram_size > 0) {
        const double budget_bytes = tree_config_->histogram_pool_size * 1024.0 * 1024.0;
        const double cacheable = budget_bytes / static_cast<double>(total_histogram_size);
        max_cache_size = static_cast<int>(
          std::min(cacheable, static_cast<double>(tree_config_->num_leaves)));
      }
    }
    // at least need 2 leaves
    max_cache_size = std::max(2, max_cache_size);
    max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
    histogram_pool_.DynamicChangeSize(max_cache_size, tree_config_->num_leaves);

    // push split information for all leaves
    best_split_per_leaf_.resize(tree_config_->num_leaves);
    data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
  }

  histogram_pool_.ResetConfig(tree_config_, train_data_->num_features());
  random_ = Random(tree_config_->feature_fraction_seed);
}

Guolin Ke's avatar
Guolin Ke committed
117
118
119
120
121
// Grows one tree leaf-wise from the given per-row gradients/hessians.
// Each iteration evaluates the leaves created by the previous split and then
// splits whichever leaf has the largest gain. Growth stops after
// num_leaves - 1 splits or when no leaf has positive gain.
// Returns a newly allocated Tree; ownership passes to the caller.
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
  gradients_ = gradients;
  hessians_ = hessians;
  BeforeTrain();

  std::unique_ptr<Tree> tree(new Tree(tree_config_->num_leaves));
  // BeforeFindBestSplit reads leaf depths through this pointer.
  last_trained_tree_ = tree.get();

  // Leaves produced by the most recent split. Initially only the root
  // (leaf 0) exists, so the right leaf is marked absent with -1.
  int left_leaf = 0;
  int right_leaf = -1;
  for (int split = 0; split < tree_config_->num_leaves - 1; ++split) {
    // BeforeFindBestSplit returns false when the new leaves are blocked
    // (depth or min_data limits); the threshold search is then skipped.
    const bool should_search = BeforeFindBestSplit(left_leaf, right_leaf);
    if (should_search) {
      FindBestThresholds();
      FindBestSplitsForLeaves();
    }

    const int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
    const SplitInfo& best = best_split_per_leaf_[best_leaf];
    if (best.gain <= 0.0) {
      Log::Info("No further splits with positive gain, best gain: %f, leaves: %d",
                best.gain, split + 1);
      break;
    }
    Split(tree.get(), best_leaf, &left_leaf, &right_leaf);
  }
  return tree.release();
}

void SerialTreeLearner::BeforeTrain() {
Guolin Ke's avatar
Guolin Ke committed
154

155
156
  // reset histogram pool
  histogram_pool_.ResetMap();
Guolin Ke's avatar
Guolin Ke committed
157
158
159
160
161
  // initialize used features
  for (int i = 0; i < num_features_; ++i) {
    is_feature_used_[i] = false;
  }
  // Get used feature at current tree
Guolin Ke's avatar
Guolin Ke committed
162
  int used_feature_cnt = static_cast<int>(num_features_*tree_config_->feature_fraction);
163
  auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt);
Guolin Ke's avatar
Guolin Ke committed
164
165
166
  for (auto idx : used_feature_indices) {
    is_feature_used_[idx] = true;
  }
167

Guolin Ke's avatar
Guolin Ke committed
168
169
170
171
  // initialize data partition
  data_partition_->Init();

  // reset the splits for leaves
Guolin Ke's avatar
Guolin Ke committed
172
  for (int i = 0; i < tree_config_->num_leaves; ++i) {
Guolin Ke's avatar
Guolin Ke committed
173
174
175
176
177
178
179
180
    best_split_per_leaf_[i].Reset();
  }

  // Sumup for root
  if (data_partition_->leaf_count(0) == num_data_) {
    // use all data
    smaller_leaf_splits_->Init(gradients_, hessians_);
    // point to gradients, avoid copy
181
182
    ptr_to_ordered_gradients_smaller_leaf_ = gradients_;
    ptr_to_ordered_hessians_smaller_leaf_  = hessians_;
Guolin Ke's avatar
Guolin Ke committed
183
184
  } else {
    // use bagging, only use part of data
Guolin Ke's avatar
Guolin Ke committed
185
    smaller_leaf_splits_->Init(0, data_partition_.get(), gradients_, hessians_);
Guolin Ke's avatar
Guolin Ke committed
186
187
188
189
190
191
192
193
194
    // copy used gradients and hessians to ordered buffer
    const data_size_t* indices = data_partition_->indices();
    data_size_t cnt = data_partition_->leaf_count(0);
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < cnt; ++i) {
      ordered_gradients_[i] = gradients_[indices[i]];
      ordered_hessians_[i] = hessians_[indices[i]];
    }
    // point to ordered_gradients_ and ordered_hessians_
Guolin Ke's avatar
Guolin Ke committed
195
196
    ptr_to_ordered_gradients_smaller_leaf_ = ordered_gradients_.data();
    ptr_to_ordered_hessians_smaller_leaf_ = ordered_hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
197
198
  }

199
200
201
  ptr_to_ordered_gradients_larger_leaf_ = nullptr;
  ptr_to_ordered_hessians_larger_leaf_ = nullptr;

Guolin Ke's avatar
Guolin Ke committed
202
203
204
205
206
207
208
209
210
  larger_leaf_splits_->Init();

  // if has ordered bin, need to initialize the ordered bin
  if (has_ordered_bin_) {
    if (data_partition_->leaf_count(0) == num_data_) {
      // use all data, pass nullptr
      #pragma omp parallel for schedule(guided)
      for (int i = 0; i < num_features_; ++i) {
        if (ordered_bins_[i] != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
211
          ordered_bins_[i]->Init(nullptr, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
212
213
214
215
216
217
        }
      }
    } else {
      // bagging, only use part of data

      // mark used data
Guolin Ke's avatar
Guolin Ke committed
218
      std::memset(is_data_in_leaf_.data(), 0, sizeof(char)*num_data_);
Guolin Ke's avatar
Guolin Ke committed
219
220
221
222
223
224
225
226
227
228
229
      const data_size_t* indices = data_partition_->indices();
      data_size_t begin = data_partition_->leaf_begin(0);
      data_size_t end = begin + data_partition_->leaf_count(0);
      #pragma omp parallel for schedule(static)
      for (data_size_t i = begin; i < end; ++i) {
        is_data_in_leaf_[indices[i]] = 1;
      }
      // initialize ordered bin
      #pragma omp parallel for schedule(guided)
      for (int i = 0; i < num_features_; ++i) {
        if (ordered_bins_[i] != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
230
          ordered_bins_[i]->Init(is_data_in_leaf_.data(), tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
231
232
233
234
235
236
237
        }
      }
    }
  }
}

bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
Guolin Ke's avatar
Guolin Ke committed
238
  // check depth of current leaf
Guolin Ke's avatar
Guolin Ke committed
239
  if (tree_config_->max_depth > 0) {
Guolin Ke's avatar
Guolin Ke committed
240
    // only need to check left leaf, since right leaf is in same level of left leaf
Guolin Ke's avatar
Guolin Ke committed
241
    if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_->max_depth) {
Guolin Ke's avatar
Guolin Ke committed
242
243
244
245
246
247
248
      best_split_per_leaf_[left_leaf].gain = kMinScore;
      if (right_leaf >= 0) {
        best_split_per_leaf_[right_leaf].gain = kMinScore;
      }
      return false;
    }
  }
Guolin Ke's avatar
Guolin Ke committed
249
250
251
  data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
  data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
  // no enough data to continue
Guolin Ke's avatar
Guolin Ke committed
252
253
  if (num_data_in_right_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)
    && num_data_in_left_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)) {
Guolin Ke's avatar
Guolin Ke committed
254
255
256
257
258
259
    best_split_per_leaf_[left_leaf].gain = kMinScore;
    if (right_leaf >= 0) {
      best_split_per_leaf_[right_leaf].gain = kMinScore;
    }
    return false;
  }
260
  parent_leaf_histogram_array_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
261
262
  // -1 if only has one leaf. else equal the index of smaller leaf
  int smaller_leaf = -1;
263
  int larger_leaf = -1;
Guolin Ke's avatar
Guolin Ke committed
264
265
  // only have root
  if (right_leaf < 0) {
266
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
267
    larger_leaf_histogram_array_ = nullptr;
268

Guolin Ke's avatar
Guolin Ke committed
269
270
  } else if (num_data_in_left_child < num_data_in_right_child) {
    smaller_leaf = left_leaf;
271
    larger_leaf = right_leaf;
Hui Xue's avatar
Hui Xue committed
272
    // put parent(left) leaf's histograms into larger leaf's histograms
273
274
275
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Move(left_leaf, right_leaf);
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
276
277
  } else {
    smaller_leaf = right_leaf;
278
    larger_leaf = left_leaf;
Hui Xue's avatar
Hui Xue committed
279
    // put parent(left) leaf's histograms to larger leaf's histograms
280
281
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Get(right_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
  }

  // init for the ordered gradients, only initialize when have 2 leaves
  if (smaller_leaf >= 0) {
    // only need to initialize for smaller leaf

    // Get leaf boundary
    const data_size_t* indices = data_partition_->indices();
    data_size_t begin = data_partition_->leaf_begin(smaller_leaf);
    data_size_t end = begin + data_partition_->leaf_count(smaller_leaf);
    // copy
    #pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      ordered_gradients_[i - begin] = gradients_[indices[i]];
      ordered_hessians_[i - begin] = hessians_[indices[i]];
    }
    // assign pointer
Guolin Ke's avatar
Guolin Ke committed
299
300
    ptr_to_ordered_gradients_smaller_leaf_ = ordered_gradients_.data();
    ptr_to_ordered_hessians_smaller_leaf_ = ordered_hessians_.data();
301
302
303
304
305
306
307
308
309
310
311
312

    if (parent_leaf_histogram_array_ == nullptr) {
      // need order gradient for larger leaf
      data_size_t smaller_size = end - begin;
      data_size_t larger_begin = data_partition_->leaf_begin(larger_leaf);
      data_size_t larger_end = larger_begin + data_partition_->leaf_count(larger_leaf);
      // copy
      #pragma omp parallel for schedule(static)
      for (data_size_t i = larger_begin; i < larger_end; ++i) {
        ordered_gradients_[smaller_size + i - larger_begin] = gradients_[indices[i]];
        ordered_hessians_[smaller_size + i - larger_begin] = hessians_[indices[i]];
      }
Guolin Ke's avatar
Guolin Ke committed
313
314
      ptr_to_ordered_gradients_larger_leaf_ = ordered_gradients_.data() + smaller_size;
      ptr_to_ordered_hessians_larger_leaf_ = ordered_hessians_.data() + smaller_size;
315
    }
Guolin Ke's avatar
Guolin Ke committed
316
317
318
319
320
  }

  // split for the ordered bin
  if (has_ordered_bin_ && right_leaf >= 0) {
    // mark data that at left-leaf
Guolin Ke's avatar
Guolin Ke committed
321
    std::memset(is_data_in_leaf_.data(), 0, sizeof(char)*num_data_);
Guolin Ke's avatar
Guolin Ke committed
322
323
324
325
326
327
328
329
330
331
332
    const data_size_t* indices = data_partition_->indices();
    data_size_t begin = data_partition_->leaf_begin(left_leaf);
    data_size_t end = begin + data_partition_->leaf_count(left_leaf);
    #pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      is_data_in_leaf_[indices[i]] = 1;
    }
    // split the ordered bin
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_features_; ++i) {
      if (ordered_bins_[i] != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
333
        ordered_bins_[i]->Split(left_leaf, right_leaf, is_data_in_leaf_.data());
Guolin Ke's avatar
Guolin Ke committed
334
335
336
337
338
339
340
341
342
343
344
      }
    }
  }
  return true;
}


void SerialTreeLearner::FindBestThresholds() {
  #pragma omp parallel for schedule(guided)
  for (int feature_index = 0; feature_index < num_features_; feature_index++) {
    // feature is not used
Guolin Ke's avatar
Guolin Ke committed
345
    if ((!is_feature_used_.empty() && is_feature_used_[feature_index] == false)) continue;
Guolin Ke's avatar
Guolin Ke committed
346
    // if parent(larger) leaf cannot split at current feature
347
    if (parent_leaf_histogram_array_ != nullptr && !parent_leaf_histogram_array_[feature_index].is_splittable()) {
Guolin Ke's avatar
Guolin Ke committed
348
349
350
351
352
353
354
355
356
357
358
      smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
      continue;
    }

    // construct histograms for smaller leaf
    if (ordered_bins_[feature_index] == nullptr) {
      // if not use ordered bin
      smaller_leaf_histogram_array_[feature_index].Construct(smaller_leaf_splits_->data_indices(),
        smaller_leaf_splits_->num_data_in_leaf(),
        smaller_leaf_splits_->sum_gradients(),
        smaller_leaf_splits_->sum_hessians(),
359
360
        ptr_to_ordered_gradients_smaller_leaf_,
        ptr_to_ordered_hessians_smaller_leaf_);
Guolin Ke's avatar
Guolin Ke committed
361
362
    } else {
      // used ordered bin
Guolin Ke's avatar
Guolin Ke committed
363
      smaller_leaf_histogram_array_[feature_index].Construct(ordered_bins_[feature_index].get(),
Guolin Ke's avatar
Guolin Ke committed
364
365
366
367
368
369
370
371
372
373
374
375
376
        smaller_leaf_splits_->LeafIndex(),
        smaller_leaf_splits_->num_data_in_leaf(),
        smaller_leaf_splits_->sum_gradients(),
        smaller_leaf_splits_->sum_hessians(),
        gradients_,
        hessians_);
    }
    // find best threshold for smaller child
    smaller_leaf_histogram_array_[feature_index].FindBestThreshold(&smaller_leaf_splits_->BestSplitPerFeature()[feature_index]);

    // only has root leaf
    if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) continue;

377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
    if (parent_leaf_histogram_array_ != nullptr) {
      // construct histgroms for large leaf, we initialize larger leaf as the parent,
      // so we can just subtract the smaller leaf's histograms
      larger_leaf_histogram_array_[feature_index].Subtract(smaller_leaf_histogram_array_[feature_index]);
    } else {
      if (ordered_bins_[feature_index] == nullptr) {
        // if not use ordered bin
        larger_leaf_histogram_array_[feature_index].Construct(larger_leaf_splits_->data_indices(),
          larger_leaf_splits_->num_data_in_leaf(),
          larger_leaf_splits_->sum_gradients(),
          larger_leaf_splits_->sum_hessians(),
          ptr_to_ordered_gradients_larger_leaf_,
          ptr_to_ordered_hessians_larger_leaf_);
      } else {
        // used ordered bin
Guolin Ke's avatar
Guolin Ke committed
392
        larger_leaf_histogram_array_[feature_index].Construct(ordered_bins_[feature_index].get(),
393
394
395
396
397
398
399
400
          larger_leaf_splits_->LeafIndex(),
          larger_leaf_splits_->num_data_in_leaf(),
          larger_leaf_splits_->sum_gradients(),
          larger_leaf_splits_->sum_hessians(),
          gradients_,
          hessians_);
      }
    }
Guolin Ke's avatar
Guolin Ke committed
401
402
403
404
405
406
407
408
409
410
411
412
413

    // find best threshold for larger child
    larger_leaf_histogram_array_[feature_index].FindBestThreshold(&larger_leaf_splits_->BestSplitPerFeature()[feature_index]);
  }
}


// Applies best_split_per_leaf_[best_Leaf] to `tree` and to the data partition.
// On return, *left_leaf == best_Leaf (the parent keeps its index) and
// *right_leaf is the index Tree::Split returned for the new child. Both leaf
// splits are then re-initialized so that the child with fewer rows is the
// "smaller" leaf for the next iteration.
void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
  const SplitInfo& info = best_split_per_leaf_[best_Leaf];
  const auto& split_feature = train_data_->FeatureAt(info.feature);

  *left_leaf = best_Leaf;
  *right_leaf = tree->Split(best_Leaf, info.feature,
    split_feature->bin_type(),
    info.threshold,
    split_feature->feature_index(),
    split_feature->BinToValue(info.threshold),
    static_cast<double>(info.left_output),
    static_cast<double>(info.right_output),
    static_cast<data_size_t>(info.left_count),
    static_cast<data_size_t>(info.right_count),
    static_cast<double>(info.gain));

  // Route the parent's rows to the two children.
  data_partition_->Split(best_Leaf, split_feature->bin_data(),
                         info.threshold, *right_leaf);

  // The child with fewer rows gets its histograms built directly. The larger
  // one may then be obtained by subtraction in FindBestThresholds.
  const bool left_is_smaller = info.left_count < info.right_count;
  if (left_is_smaller) {
    smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                               info.left_sum_gradient, info.left_sum_hessian);
    larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                              info.right_sum_gradient, info.right_sum_hessian);
  } else {
    smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                               info.right_sum_gradient, info.right_sum_hessian);
    larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                              info.left_sum_gradient, info.left_sum_hessian);
  }
}

}  // namespace LightGBM