"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "668bf5dadf1eb9a846302b2b76a313fbbef52870"
serial_tree_learner.cpp 17.4 KB
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
4
5
6
7
8
9
#include "serial_tree_learner.h"

#include <LightGBM/utils/array_args.h>

#include <algorithm>
#include <vector>

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
10
// Creates a learner that reads its settings from tree_config.
// Only the pointer is stored, so the configuration object must outlive the learner.
SerialTreeLearner::SerialTreeLearner(const TreeConfig* tree_config)
  : tree_config_(tree_config) {
  // seed the generator that BeforeTrain() uses to sample features
  random_ = Random(tree_config->feature_fraction_seed);
}

// Nothing to release by hand: the destructor body is empty and the members clean up themselves.
SerialTreeLearner::~SerialTreeLearner() = default;

void SerialTreeLearner::Init(const Dataset* train_data) {
  train_data_ = train_data;
  num_data_ = train_data_->num_data();
  num_features_ = train_data_->num_features();
23
24
  int max_cache_size = 0;
  // Get the max size of pool
Guolin Ke's avatar
Guolin Ke committed
25
26
  if (tree_config_->histogram_pool_size <= 0) {
    max_cache_size = tree_config_->num_leaves;
27
28
29
30
31
  } else {
    size_t total_histogram_size = 0;
    for (int i = 0; i < train_data_->num_features(); ++i) {
      total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
    }
Guolin Ke's avatar
Guolin Ke committed
32
    max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
33
34
  }
  // at least need 2 leaves
Guolin Ke's avatar
Guolin Ke committed
35
  max_cache_size = std::max(2, max_cache_size);
Guolin Ke's avatar
Guolin Ke committed
36
37
  max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
  histogram_pool_.Reset(max_cache_size, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
38

Guolin Ke's avatar
Guolin Ke committed
39
  auto histogram_create_function = [this]() {
Guolin Ke's avatar
Guolin Ke committed
40
    auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
Guolin Ke's avatar
Guolin Ke committed
41
    for (int j = 0; j < train_data_->num_features(); ++j) {
42
      tmp_histogram_array[j].Init(train_data_->FeatureAt(j),
Guolin Ke's avatar
Guolin Ke committed
43
        j, tree_config_);
Guolin Ke's avatar
Guolin Ke committed
44
    }
Guolin Ke's avatar
Guolin Ke committed
45
    return tmp_histogram_array.release();
Guolin Ke's avatar
Guolin Ke committed
46
47
48
  };
  histogram_pool_.Fill(histogram_create_function);

Guolin Ke's avatar
Guolin Ke committed
49
  // push split information for all leaves
Guolin Ke's avatar
Guolin Ke committed
50
  best_split_per_leaf_.resize(tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
51
  // initialize ordered_bins_ with nullptr
Guolin Ke's avatar
Guolin Ke committed
52
  ordered_bins_.resize(num_features_);
Guolin Ke's avatar
Guolin Ke committed
53
54
55
56

  // get ordered bin
  #pragma omp parallel for schedule(guided)
  for (int i = 0; i < num_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
57
    ordered_bins_[i].reset(train_data_->FeatureAt(i)->bin_data()->CreateOrderedBin());
Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
62
63
64
65
66
  }

  // check existing for ordered bin
  for (int i = 0; i < num_features_; ++i) {
    if (ordered_bins_[i] != nullptr) {
      has_ordered_bin_ = true;
      break;
    }
  }
wxchan's avatar
wxchan committed
67
  // initialize splits for leaf
Guolin Ke's avatar
Guolin Ke committed
68
69
  smaller_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
  larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
Guolin Ke's avatar
Guolin Ke committed
70
71

  // initialize data partition
Guolin Ke's avatar
Guolin Ke committed
72
  data_partition_.reset(new DataPartition(num_data_, tree_config_->num_leaves));
Guolin Ke's avatar
Guolin Ke committed
73

Guolin Ke's avatar
Guolin Ke committed
74
  is_feature_used_.resize(num_features_);
Guolin Ke's avatar
Guolin Ke committed
75
76

  // initialize ordered gradients and hessians
Guolin Ke's avatar
Guolin Ke committed
77
78
79
  ordered_gradients_.resize(num_data_);
  ordered_hessians_.resize(num_data_);
  // if has ordered bin, need to allocate a buffer to fast split
Guolin Ke's avatar
Guolin Ke committed
80
  if (has_ordered_bin_) {
Guolin Ke's avatar
Guolin Ke committed
81
    is_data_in_leaf_.resize(num_data_);
Guolin Ke's avatar
Guolin Ke committed
82
  }
83
  Log::Info("Number of data: %d, number of features: %d", num_data_, num_features_);
Guolin Ke's avatar
Guolin Ke committed
84
85
86
}


Guolin Ke's avatar
Guolin Ke committed
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
void SerialTreeLearner::ResetConfig(const TreeConfig* tree_config) {
  if (tree_config_->num_leaves != tree_config->num_leaves) {
    tree_config_ = tree_config;
    int max_cache_size = 0;
    // Get the max size of pool
    if (tree_config->histogram_pool_size <= 0) {
      max_cache_size = tree_config_->num_leaves;
    } else {
      size_t total_histogram_size = 0;
      for (int i = 0; i < train_data_->num_features(); ++i) {
        total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
      }
      max_cache_size = static_cast<int>(tree_config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
    }
    // at least need 2 leaves
    max_cache_size = std::max(2, max_cache_size);
    max_cache_size = std::min(max_cache_size, tree_config_->num_leaves);
    histogram_pool_.DynamicChangeSize(max_cache_size, tree_config_->num_leaves);

    // push split information for all leaves
    best_split_per_leaf_.resize(tree_config_->num_leaves);
108
    data_partition_->ResetLeaves(tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
114
115
  } else {
    tree_config_ = tree_config;
  }

  histogram_pool_.ResetConfig(tree_config_, train_data_->num_features());
}

Guolin Ke's avatar
Guolin Ke committed
116
117
118
119
120
// Grows one tree from the given per-row gradients and hessians.
// The pointers are stored and read while the tree is being grown.
// Returns a heap-allocated tree whose ownership passes to the caller; the
// learner additionally keeps a non-owning pointer to it in last_trained_tree_.
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
  gradients_ = gradients;
  hessians_ = hessians;
  // per-tree preparation (feature sampling, partition reset, root statistics)
  BeforeTrain();

  std::unique_ptr<Tree> new_tree(new Tree(tree_config_->num_leaves));
  // remember the tree under construction
  last_trained_tree_ = new_tree.get();

  // the first round only has the root, stored as the "left" leaf;
  // -1 marks that there is no right leaf yet
  int left_leaf = 0;
  int right_leaf = -1;

  // a tree with num_leaves leaves needs at most num_leaves - 1 splits
  const int max_splits = tree_config_->num_leaves - 1;
  for (int split = 0; split < max_splits; ++split) {
    // refresh split candidates for the two newest leaves, unless they are ruled out
    if (BeforeFindBestSplit(left_leaf, right_leaf)) {
      FindBestThresholds();
      FindBestSplitsForLeaves();
    }

    // choose the leaf whose recorded best split has the largest gain
    const int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
    const SplitInfo& best_candidate = best_split_per_leaf_[best_leaf];

    // nothing left that improves the tree: stop growing
    if (best_candidate.gain <= 0.0) {
      Log::Info("No further splits with positive gain, best gain: %f, leaves: %d",
                best_candidate.gain, split + 1);
      break;
    }

    // perform the split; left_leaf/right_leaf now name the two children
    Split(new_tree.get(), best_leaf, &left_leaf, &right_leaf);
  }
  return new_tree.release();
}

void SerialTreeLearner::BeforeTrain() {
Guolin Ke's avatar
Guolin Ke committed
153

154
155
  // reset histogram pool
  histogram_pool_.ResetMap();
Guolin Ke's avatar
Guolin Ke committed
156
157
158
159
160
  // initialize used features
  for (int i = 0; i < num_features_; ++i) {
    is_feature_used_[i] = false;
  }
  // Get used feature at current tree
Guolin Ke's avatar
Guolin Ke committed
161
  int used_feature_cnt = static_cast<int>(num_features_*tree_config_->feature_fraction);
162
  auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt);
Guolin Ke's avatar
Guolin Ke committed
163
164
165
  for (auto idx : used_feature_indices) {
    is_feature_used_[idx] = true;
  }
166

Guolin Ke's avatar
Guolin Ke committed
167
168
169
170
  // initialize data partition
  data_partition_->Init();

  // reset the splits for leaves
Guolin Ke's avatar
Guolin Ke committed
171
  for (int i = 0; i < tree_config_->num_leaves; ++i) {
Guolin Ke's avatar
Guolin Ke committed
172
173
174
175
176
177
178
179
    best_split_per_leaf_[i].Reset();
  }

  // Sumup for root
  if (data_partition_->leaf_count(0) == num_data_) {
    // use all data
    smaller_leaf_splits_->Init(gradients_, hessians_);
    // point to gradients, avoid copy
180
181
    ptr_to_ordered_gradients_smaller_leaf_ = gradients_;
    ptr_to_ordered_hessians_smaller_leaf_  = hessians_;
Guolin Ke's avatar
Guolin Ke committed
182
183
  } else {
    // use bagging, only use part of data
Guolin Ke's avatar
Guolin Ke committed
184
    smaller_leaf_splits_->Init(0, data_partition_.get(), gradients_, hessians_);
Guolin Ke's avatar
Guolin Ke committed
185
186
187
188
189
190
191
192
193
    // copy used gradients and hessians to ordered buffer
    const data_size_t* indices = data_partition_->indices();
    data_size_t cnt = data_partition_->leaf_count(0);
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < cnt; ++i) {
      ordered_gradients_[i] = gradients_[indices[i]];
      ordered_hessians_[i] = hessians_[indices[i]];
    }
    // point to ordered_gradients_ and ordered_hessians_
Guolin Ke's avatar
Guolin Ke committed
194
195
    ptr_to_ordered_gradients_smaller_leaf_ = ordered_gradients_.data();
    ptr_to_ordered_hessians_smaller_leaf_ = ordered_hessians_.data();
Guolin Ke's avatar
Guolin Ke committed
196
197
  }

198
199
200
  ptr_to_ordered_gradients_larger_leaf_ = nullptr;
  ptr_to_ordered_hessians_larger_leaf_ = nullptr;

Guolin Ke's avatar
Guolin Ke committed
201
202
203
204
205
206
207
208
209
  larger_leaf_splits_->Init();

  // if has ordered bin, need to initialize the ordered bin
  if (has_ordered_bin_) {
    if (data_partition_->leaf_count(0) == num_data_) {
      // use all data, pass nullptr
      #pragma omp parallel for schedule(guided)
      for (int i = 0; i < num_features_; ++i) {
        if (ordered_bins_[i] != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
210
          ordered_bins_[i]->Init(nullptr, tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
211
212
213
214
215
216
        }
      }
    } else {
      // bagging, only use part of data

      // mark used data
Guolin Ke's avatar
Guolin Ke committed
217
      std::memset(is_data_in_leaf_.data(), 0, sizeof(char)*num_data_);
Guolin Ke's avatar
Guolin Ke committed
218
219
220
221
222
223
224
225
226
227
228
      const data_size_t* indices = data_partition_->indices();
      data_size_t begin = data_partition_->leaf_begin(0);
      data_size_t end = begin + data_partition_->leaf_count(0);
      #pragma omp parallel for schedule(static)
      for (data_size_t i = begin; i < end; ++i) {
        is_data_in_leaf_[indices[i]] = 1;
      }
      // initialize ordered bin
      #pragma omp parallel for schedule(guided)
      for (int i = 0; i < num_features_; ++i) {
        if (ordered_bins_[i] != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
229
          ordered_bins_[i]->Init(is_data_in_leaf_.data(), tree_config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
230
231
232
233
234
235
236
        }
      }
    }
  }
}

bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
Guolin Ke's avatar
Guolin Ke committed
237
  // check depth of current leaf
Guolin Ke's avatar
Guolin Ke committed
238
  if (tree_config_->max_depth > 0) {
Guolin Ke's avatar
Guolin Ke committed
239
    // only need to check left leaf, since right leaf is in same level of left leaf
Guolin Ke's avatar
Guolin Ke committed
240
    if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_->max_depth) {
Guolin Ke's avatar
Guolin Ke committed
241
242
243
244
245
246
247
      best_split_per_leaf_[left_leaf].gain = kMinScore;
      if (right_leaf >= 0) {
        best_split_per_leaf_[right_leaf].gain = kMinScore;
      }
      return false;
    }
  }
Guolin Ke's avatar
Guolin Ke committed
248
249
250
  data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
  data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
  // no enough data to continue
Guolin Ke's avatar
Guolin Ke committed
251
252
  if (num_data_in_right_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)
    && num_data_in_left_child < static_cast<data_size_t>(tree_config_->min_data_in_leaf * 2)) {
Guolin Ke's avatar
Guolin Ke committed
253
254
255
256
257
258
    best_split_per_leaf_[left_leaf].gain = kMinScore;
    if (right_leaf >= 0) {
      best_split_per_leaf_[right_leaf].gain = kMinScore;
    }
    return false;
  }
259
  parent_leaf_histogram_array_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
260
261
  // -1 if only has one leaf. else equal the index of smaller leaf
  int smaller_leaf = -1;
262
  int larger_leaf = -1;
Guolin Ke's avatar
Guolin Ke committed
263
264
  // only have root
  if (right_leaf < 0) {
265
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
266
    larger_leaf_histogram_array_ = nullptr;
267

Guolin Ke's avatar
Guolin Ke committed
268
269
  } else if (num_data_in_left_child < num_data_in_right_child) {
    smaller_leaf = left_leaf;
270
    larger_leaf = right_leaf;
Hui Xue's avatar
Hui Xue committed
271
    // put parent(left) leaf's histograms into larger leaf's histograms
272
273
274
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Move(left_leaf, right_leaf);
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
275
276
  } else {
    smaller_leaf = right_leaf;
277
    larger_leaf = left_leaf;
Hui Xue's avatar
Hui Xue committed
278
    // put parent(left) leaf's histograms to larger leaf's histograms
279
280
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Get(right_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
  }

  // init for the ordered gradients, only initialize when have 2 leaves
  if (smaller_leaf >= 0) {
    // only need to initialize for smaller leaf

    // Get leaf boundary
    const data_size_t* indices = data_partition_->indices();
    data_size_t begin = data_partition_->leaf_begin(smaller_leaf);
    data_size_t end = begin + data_partition_->leaf_count(smaller_leaf);
    // copy
    #pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      ordered_gradients_[i - begin] = gradients_[indices[i]];
      ordered_hessians_[i - begin] = hessians_[indices[i]];
    }
    // assign pointer
Guolin Ke's avatar
Guolin Ke committed
298
299
    ptr_to_ordered_gradients_smaller_leaf_ = ordered_gradients_.data();
    ptr_to_ordered_hessians_smaller_leaf_ = ordered_hessians_.data();
300
301
302
303
304
305
306
307
308
309
310
311

    if (parent_leaf_histogram_array_ == nullptr) {
      // need order gradient for larger leaf
      data_size_t smaller_size = end - begin;
      data_size_t larger_begin = data_partition_->leaf_begin(larger_leaf);
      data_size_t larger_end = larger_begin + data_partition_->leaf_count(larger_leaf);
      // copy
      #pragma omp parallel for schedule(static)
      for (data_size_t i = larger_begin; i < larger_end; ++i) {
        ordered_gradients_[smaller_size + i - larger_begin] = gradients_[indices[i]];
        ordered_hessians_[smaller_size + i - larger_begin] = hessians_[indices[i]];
      }
Guolin Ke's avatar
Guolin Ke committed
312
313
      ptr_to_ordered_gradients_larger_leaf_ = ordered_gradients_.data() + smaller_size;
      ptr_to_ordered_hessians_larger_leaf_ = ordered_hessians_.data() + smaller_size;
314
    }
Guolin Ke's avatar
Guolin Ke committed
315
316
317
318
319
  }

  // split for the ordered bin
  if (has_ordered_bin_ && right_leaf >= 0) {
    // mark data that at left-leaf
Guolin Ke's avatar
Guolin Ke committed
320
    std::memset(is_data_in_leaf_.data(), 0, sizeof(char)*num_data_);
Guolin Ke's avatar
Guolin Ke committed
321
322
323
324
325
326
327
328
329
330
331
    const data_size_t* indices = data_partition_->indices();
    data_size_t begin = data_partition_->leaf_begin(left_leaf);
    data_size_t end = begin + data_partition_->leaf_count(left_leaf);
    #pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      is_data_in_leaf_[indices[i]] = 1;
    }
    // split the ordered bin
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_features_; ++i) {
      if (ordered_bins_[i] != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
332
        ordered_bins_[i]->Split(left_leaf, right_leaf, is_data_in_leaf_.data());
Guolin Ke's avatar
Guolin Ke committed
333
334
335
336
337
338
339
340
341
342
343
      }
    }
  }
  return true;
}


void SerialTreeLearner::FindBestThresholds() {
  #pragma omp parallel for schedule(guided)
  for (int feature_index = 0; feature_index < num_features_; feature_index++) {
    // feature is not used
Guolin Ke's avatar
Guolin Ke committed
344
    if ((!is_feature_used_.empty() && is_feature_used_[feature_index] == false)) continue;
Guolin Ke's avatar
Guolin Ke committed
345
    // if parent(larger) leaf cannot split at current feature
346
    if (parent_leaf_histogram_array_ != nullptr && !parent_leaf_histogram_array_[feature_index].is_splittable()) {
Guolin Ke's avatar
Guolin Ke committed
347
348
349
350
351
352
353
354
355
356
357
      smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
      continue;
    }

    // construct histograms for smaller leaf
    if (ordered_bins_[feature_index] == nullptr) {
      // if not use ordered bin
      smaller_leaf_histogram_array_[feature_index].Construct(smaller_leaf_splits_->data_indices(),
        smaller_leaf_splits_->num_data_in_leaf(),
        smaller_leaf_splits_->sum_gradients(),
        smaller_leaf_splits_->sum_hessians(),
358
359
        ptr_to_ordered_gradients_smaller_leaf_,
        ptr_to_ordered_hessians_smaller_leaf_);
Guolin Ke's avatar
Guolin Ke committed
360
361
    } else {
      // used ordered bin
Guolin Ke's avatar
Guolin Ke committed
362
      smaller_leaf_histogram_array_[feature_index].Construct(ordered_bins_[feature_index].get(),
Guolin Ke's avatar
Guolin Ke committed
363
364
365
366
367
368
369
370
371
372
373
374
375
        smaller_leaf_splits_->LeafIndex(),
        smaller_leaf_splits_->num_data_in_leaf(),
        smaller_leaf_splits_->sum_gradients(),
        smaller_leaf_splits_->sum_hessians(),
        gradients_,
        hessians_);
    }
    // find best threshold for smaller child
    smaller_leaf_histogram_array_[feature_index].FindBestThreshold(&smaller_leaf_splits_->BestSplitPerFeature()[feature_index]);

    // only has root leaf
    if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) continue;

376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
    if (parent_leaf_histogram_array_ != nullptr) {
      // construct histgroms for large leaf, we initialize larger leaf as the parent,
      // so we can just subtract the smaller leaf's histograms
      larger_leaf_histogram_array_[feature_index].Subtract(smaller_leaf_histogram_array_[feature_index]);
    } else {
      if (ordered_bins_[feature_index] == nullptr) {
        // if not use ordered bin
        larger_leaf_histogram_array_[feature_index].Construct(larger_leaf_splits_->data_indices(),
          larger_leaf_splits_->num_data_in_leaf(),
          larger_leaf_splits_->sum_gradients(),
          larger_leaf_splits_->sum_hessians(),
          ptr_to_ordered_gradients_larger_leaf_,
          ptr_to_ordered_hessians_larger_leaf_);
      } else {
        // used ordered bin
Guolin Ke's avatar
Guolin Ke committed
391
        larger_leaf_histogram_array_[feature_index].Construct(ordered_bins_[feature_index].get(),
392
393
394
395
396
397
398
399
          larger_leaf_splits_->LeafIndex(),
          larger_leaf_splits_->num_data_in_leaf(),
          larger_leaf_splits_->sum_gradients(),
          larger_leaf_splits_->sum_hessians(),
          gradients_,
          hessians_);
      }
    }
Guolin Ke's avatar
Guolin Ke committed
400
401
402
403
404
405
406
407
408
409
410
411
412

    // find best threshold for larger child
    larger_leaf_histogram_array_[feature_index].FindBestThreshold(&larger_leaf_splits_->BestSplitPerFeature()[feature_index]);
  }
}


// Applies the best recorded split of leaf best_Leaf.
// tree       : tree being grown; receives the new split
// best_Leaf  : index of the leaf to split; it becomes the left child
// left_leaf  : out, set to best_Leaf
// right_leaf : out, set to the index returned by tree->Split
// Afterwards the data partition is split accordingly and the smaller/larger
// leaf helpers are re-initialized for the next round.
void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
  const SplitInfo& chosen = best_split_per_leaf_[best_Leaf];
  // feature the split is made on
  const auto& split_feature = train_data_->FeatureAt(chosen.feature);

  // the parent keeps its index and becomes the left child
  *left_leaf = best_Leaf;
  // split tree, will return right leaf
  *right_leaf = tree->Split(best_Leaf,
                            chosen.feature,
                            split_feature->bin_type(),
                            chosen.threshold,
                            split_feature->feature_index(),
                            split_feature->BinToValue(chosen.threshold),
                            static_cast<double>(chosen.left_output),
                            static_cast<double>(chosen.right_output),
                            static_cast<data_size_t>(chosen.left_count),
                            static_cast<data_size_t>(chosen.right_count),
                            static_cast<double>(chosen.gain));

  // split data partition
  data_partition_->Split(best_Leaf, split_feature->bin_data(), chosen.threshold, *right_leaf);

  // init the leaves that used on next iteration:
  // the child with fewer rows becomes the "smaller" leaf
  if (chosen.left_count < chosen.right_count) {
    smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                               chosen.left_sum_gradient, chosen.left_sum_hessian);
    larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                              chosen.right_sum_gradient, chosen.right_sum_hessian);
  } else {
    smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                               chosen.right_sum_gradient, chosen.right_sum_hessian);
    larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                              chosen.left_sum_gradient, chosen.left_sum_hessian);
  }
}

}  // namespace LightGBM