serial_tree_learner.cpp 35.3 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include "serial_tree_learner.h"

7
8
9
10
11
#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/common.h>

12
13
#include <algorithm>
#include <queue>
14
#include <set>
15
16
17
#include <unordered_map>
#include <utility>

18
19
#include "cost_effective_gradient_boosting.hpp"

Guolin Ke's avatar
Guolin Ke committed
20
21
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
22
// Construct a serial (single-machine) tree learner; real setup happens in Init().
SerialTreeLearner::SerialTreeLearner(const Config* config)
    : config_(config), col_sampler_(config) {}

// All members are RAII-managed; nothing to release explicitly.
SerialTreeLearner::~SerialTreeLearner() {}

29
void SerialTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
Guolin Ke's avatar
Guolin Ke committed
30
31
32
  train_data_ = train_data;
  num_data_ = train_data_->num_data();
  num_features_ = train_data_->num_features();
33
34
  int max_cache_size = 0;
  // Get the max size of pool
Guolin Ke's avatar
Guolin Ke committed
35
36
  if (config_->histogram_pool_size <= 0) {
    max_cache_size = config_->num_leaves;
37
38
39
  } else {
    size_t total_histogram_size = 0;
    for (int i = 0; i < train_data_->num_features(); ++i) {
40
      total_histogram_size += kHistEntrySize * train_data_->FeatureNumBin(i);
41
    }
Guolin Ke's avatar
Guolin Ke committed
42
    max_cache_size = static_cast<int>(config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
43
44
  }
  // at least need 2 leaves
Guolin Ke's avatar
Guolin Ke committed
45
  max_cache_size = std::max(2, max_cache_size);
Guolin Ke's avatar
Guolin Ke committed
46
  max_cache_size = std::min(max_cache_size, config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
47

Guolin Ke's avatar
Guolin Ke committed
48
  // push split information for all leaves
Guolin Ke's avatar
Guolin Ke committed
49
  best_split_per_leaf_.resize(config_->num_leaves);
50
  constraints_.reset(LeafConstraintsBase::Create(config_, config_->num_leaves, train_data_->num_features()));
Guolin Ke's avatar
Guolin Ke committed
51

wxchan's avatar
wxchan committed
52
  // initialize splits for leaf
Guolin Ke's avatar
Guolin Ke committed
53
54
  smaller_leaf_splits_.reset(new LeafSplits(train_data_->num_data(), config_));
  larger_leaf_splits_.reset(new LeafSplits(train_data_->num_data(), config_));
Guolin Ke's avatar
Guolin Ke committed
55
56

  // initialize data partition
Guolin Ke's avatar
Guolin Ke committed
57
  data_partition_.reset(new DataPartition(num_data_, config_->num_leaves));
58
  col_sampler_.SetTrainingData(train_data_);
Guolin Ke's avatar
Guolin Ke committed
59
  // initialize ordered gradients and hessians
Guolin Ke's avatar
Guolin Ke committed
60
61
  ordered_gradients_.resize(num_data_);
  ordered_hessians_.resize(num_data_);
62

63
  GetShareStates(train_data_, is_constant_hessian, true);
64
65
66
67
  histogram_pool_.DynamicChangeSize(train_data_,
  share_state_->num_hist_total_bin(),
  share_state_->feature_hist_offsets(),
  config_, max_cache_size, config_->num_leaves);
68
  Log::Info("Number of data points in the train set: %d, number of used features: %d", num_data_, num_features_);
69
70
71
  if (CostEfficientGradientBoosting::IsEnable(config_)) {
    cegb_.reset(new CostEfficientGradientBoosting(this));
    cegb_->Init();
72
  }
Guolin Ke's avatar
Guolin Ke committed
73
74
}

75
76
77
void SerialTreeLearner::GetShareStates(const Dataset* dataset,
                                       bool is_constant_hessian,
                                       bool is_first_time) {
78
  if (is_first_time) {
79
    share_state_.reset(dataset->GetShareStates(
80
81
82
        ordered_gradients_.data(), ordered_hessians_.data(),
        col_sampler_.is_feature_used_bytree(), is_constant_hessian,
        config_->force_col_wise, config_->force_row_wise));
83
  } else {
Nikita Titov's avatar
Nikita Titov committed
84
    CHECK_NOTNULL(share_state_);
85
    // cannot change is_hist_col_wise during training
86
    share_state_.reset(dataset->GetShareStates(
87
        ordered_gradients_.data(), ordered_hessians_.data(), col_sampler_.is_feature_used_bytree(),
88
89
        is_constant_hessian, share_state_->is_col_wise,
        !share_state_->is_col_wise));
90
  }
Nikita Titov's avatar
Nikita Titov committed
91
  CHECK_NOTNULL(share_state_);
92
93
}

94
95
96
void SerialTreeLearner::ResetTrainingDataInner(const Dataset* train_data,
                                               bool is_constant_hessian,
                                               bool reset_multi_val_bin) {
Guolin Ke's avatar
Guolin Ke committed
97
98
  train_data_ = train_data;
  num_data_ = train_data_->num_data();
99
  CHECK_EQ(num_features_, train_data_->num_features());
Guolin Ke's avatar
Guolin Ke committed
100
101
102
103
104
105
106

  // initialize splits for leaf
  smaller_leaf_splits_->ResetNumData(num_data_);
  larger_leaf_splits_->ResetNumData(num_data_);

  // initialize data partition
  data_partition_->ResetNumData(num_data_);
107
  if (reset_multi_val_bin) {
108
    col_sampler_.SetTrainingData(train_data_);
109
110
    GetShareStates(train_data_, is_constant_hessian, false);
  }
111

Guolin Ke's avatar
Guolin Ke committed
112
113
114
  // initialize ordered gradients and hessians
  ordered_gradients_.resize(num_data_);
  ordered_hessians_.resize(num_data_);
115
116
117
  if (cegb_ != nullptr) {
    cegb_->Init();
  }
Guolin Ke's avatar
Guolin Ke committed
118
}
Guolin Ke's avatar
Guolin Ke committed
119

Guolin Ke's avatar
Guolin Ke committed
120
121
122
void SerialTreeLearner::ResetConfig(const Config* config) {
  if (config_->num_leaves != config->num_leaves) {
    config_ = config;
Guolin Ke's avatar
Guolin Ke committed
123
124
    int max_cache_size = 0;
    // Get the max size of pool
Guolin Ke's avatar
Guolin Ke committed
125
126
    if (config->histogram_pool_size <= 0) {
      max_cache_size = config_->num_leaves;
Guolin Ke's avatar
Guolin Ke committed
127
128
129
    } else {
      size_t total_histogram_size = 0;
      for (int i = 0; i < train_data_->num_features(); ++i) {
130
        total_histogram_size += kHistEntrySize * train_data_->FeatureNumBin(i);
Guolin Ke's avatar
Guolin Ke committed
131
      }
Guolin Ke's avatar
Guolin Ke committed
132
      max_cache_size = static_cast<int>(config_->histogram_pool_size * 1024 * 1024 / total_histogram_size);
Guolin Ke's avatar
Guolin Ke committed
133
134
135
    }
    // at least need 2 leaves
    max_cache_size = std::max(2, max_cache_size);
Guolin Ke's avatar
Guolin Ke committed
136
    max_cache_size = std::min(max_cache_size, config_->num_leaves);
137
138
139
140
    histogram_pool_.DynamicChangeSize(train_data_,
    share_state_->num_hist_total_bin(),
    share_state_->feature_hist_offsets(),
    config_, max_cache_size, config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
141
142

    // push split information for all leaves
Guolin Ke's avatar
Guolin Ke committed
143
144
    best_split_per_leaf_.resize(config_->num_leaves);
    data_partition_->ResetLeaves(config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
145
  } else {
Guolin Ke's avatar
Guolin Ke committed
146
    config_ = config;
Guolin Ke's avatar
Guolin Ke committed
147
  }
148
  col_sampler_.SetConfig(config_);
149
  histogram_pool_.ResetConfig(train_data_, config_);
150
  if (CostEfficientGradientBoosting::IsEnable(config_)) {
151
152
153
    if (cegb_ == nullptr) {
      cegb_.reset(new CostEfficientGradientBoosting(this));
    }
154
155
    cegb_->Init();
  }
156
  constraints_.reset(LeafConstraintsBase::Create(config_, config_->num_leaves, train_data_->num_features()));
Guolin Ke's avatar
Guolin Ke committed
157
158
}

159
/*!
 * \brief Grow one tree greedily, leaf by leaf.
 *
 * Applies user-forced splits first, then repeatedly splits the leaf with the
 * largest gain until the leaf budget is exhausted or no positive-gain split
 * remains. Ownership of the returned tree passes to the caller.
 */
Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians, bool /*is_first_tree*/) {
  Common::FunctionTimer fun_timer("SerialTreeLearner::Train", global_timer);
  gradients_ = gradients;
  hessians_ = hessians;
  const int num_threads = OMP_NUM_THREADS();
  if (share_state_->num_threads != num_threads && share_state_->num_threads > 0) {
    Log::Warning(
        "Detected that num_threads changed during training (from %d to %d), "
        "it may cause unexpected errors.",
        share_state_->num_threads, num_threads);
  }
  share_state_->num_threads = num_threads;

  // some initial works before training
  BeforeTrain();

  const bool track_branch_features = !(config_->interaction_constraints_vector.empty());
  auto tree = std::unique_ptr<Tree>(new Tree(config_->num_leaves, track_branch_features, false));
  Tree* tree_ptr = tree.get();
  constraints_->ShareTreePointer(tree_ptr);

  // root leaf
  int left_leaf = 0;
  int cur_depth = 1;
  // only root leaf can be splitted on first time
  int right_leaf = -1;

  // forced splits consume part of the leaf budget up front
  const int init_splits = ForceSplits(tree_ptr, &left_leaf, &right_leaf, &cur_depth);

  for (int split = init_splits; split < config_->num_leaves - 1; ++split) {
    // some initial works before finding best split
    if (BeforeFindBestSplit(tree_ptr, left_leaf, right_leaf)) {
      // find best threshold for every feature
      FindBestSplits(tree_ptr);
    }
    // Get a leaf with max split gain
    const int best_leaf = static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
    // Get split information for best leaf
    const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf];
    // cannot split, quit
    if (best_leaf_SplitInfo.gain <= 0.0) {
      Log::Warning("No further splits with positive gain, best gain: %f", best_leaf_SplitInfo.gain);
      break;
    }
    // split tree with best leaf
    Split(tree_ptr, best_leaf, &left_leaf, &right_leaf);
    cur_depth = std::max(cur_depth, tree->leaf_depth(left_leaf));
  }

  Log::Debug("Trained a tree with leaves = %d and depth = %d", tree->num_leaves(), cur_depth);
  return tree.release();
}

212
/*!
 * \brief Re-fit the leaf outputs of an existing tree structure.
 *
 * Keeps old_tree's structure, recomputes each leaf's output from the new
 * gradients/hessians of the rows currently partitioned into that leaf, and
 * blends old and new outputs by config_->refit_decay_rate.
 *
 * \return a new tree owned by the caller
 */
Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const score_t* gradients, const score_t *hessians) const {
  auto tree = std::unique_ptr<Tree>(new Tree(*old_tree));
  CHECK_GE(data_partition_->num_leaves(), tree->num_leaves());
  OMP_INIT_EX();
  #pragma omp parallel for schedule(static)
  for (int i = 0; i < tree->num_leaves(); ++i) {
    OMP_LOOP_EX_BEGIN();
    data_size_t cnt_leaf_data = 0;
    auto tmp_idx = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data);
    double sum_grad = 0.0f;
    // start from kEpsilon so the hessian sum is never exactly zero
    double sum_hess = kEpsilon;
    for (data_size_t j = 0; j < cnt_leaf_data; ++j) {
      auto idx = tmp_idx[j];
      sum_grad += gradients[idx];
      sum_hess += hessians[idx];
    }
    double output;
    // logical && instead of bitwise & on the two boolean conditions
    // (identical result for bools, but idiomatic and short-circuiting)
    if ((config_->path_smooth > kEpsilon) && (i > 0)) {
      output = FeatureHistogram::CalculateSplittedLeafOutput<true, true, true>(
          sum_grad, sum_hess, config_->lambda_l1, config_->lambda_l2,
          config_->max_delta_step, config_->path_smooth, cnt_leaf_data, tree->leaf_parent(i));
    } else {
      output = FeatureHistogram::CalculateSplittedLeafOutput<true, true, false>(
          sum_grad, sum_hess, config_->lambda_l1, config_->lambda_l2,
          config_->max_delta_step, config_->path_smooth, cnt_leaf_data, 0);
    }
    auto old_leaf_output = tree->LeafOutput(i);
    auto new_leaf_output = output * tree->shrinkage();
    // decayed blend between the previous output and the refitted output
    tree->SetLeafOutput(i, config_->refit_decay_rate * old_leaf_output + (1.0 - config_->refit_decay_rate) * new_leaf_output);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  return tree.release();
}

247
248
/*!
 * \brief Re-fit an existing tree after restoring the data partition from a
 *        precomputed per-row leaf prediction, then delegate to the two-argument
 *        overload.
 */
Tree* SerialTreeLearner::FitByExistingTree(const Tree* old_tree, const std::vector<int>& leaf_pred,
                                           const score_t* gradients, const score_t *hessians) const {
  data_partition_->ResetByLeafPred(leaf_pred, old_tree->num_leaves());
  return FitByExistingTree(old_tree, gradients, hessians);
}

Guolin Ke's avatar
Guolin Ke committed
253
void SerialTreeLearner::BeforeTrain() {
254
  Common::FunctionTimer fun_timer("SerialTreeLearner::BeforeTrain", global_timer);
255
256
  // reset histogram pool
  histogram_pool_.ResetMap();
Guolin Ke's avatar
Guolin Ke committed
257

258
259
  col_sampler_.ResetByTree();
  train_data_->InitTrain(col_sampler_.is_feature_used_bytree(), share_state_.get());
Guolin Ke's avatar
Guolin Ke committed
260
261
262
  // initialize data partition
  data_partition_->Init();

263
264
  constraints_->Reset();

Guolin Ke's avatar
Guolin Ke committed
265
  // reset the splits for leaves
Guolin Ke's avatar
Guolin Ke committed
266
  for (int i = 0; i < config_->num_leaves; ++i) {
Guolin Ke's avatar
Guolin Ke committed
267
268
269
270
271
272
273
    best_split_per_leaf_[i].Reset();
  }

  // Sumup for root
  if (data_partition_->leaf_count(0) == num_data_) {
    // use all data
    smaller_leaf_splits_->Init(gradients_, hessians_);
Guolin Ke's avatar
Guolin Ke committed
274

Guolin Ke's avatar
Guolin Ke committed
275
276
  } else {
    // use bagging, only use part of data
Guolin Ke's avatar
Guolin Ke committed
277
    smaller_leaf_splits_->Init(0, data_partition_.get(), gradients_, hessians_);
Guolin Ke's avatar
Guolin Ke committed
278
279
280
  }

  larger_leaf_splits_->Init();
281
282
283
284

  if (cegb_ != nullptr) {
    cegb_->BeforeTrain();
  }
Guolin Ke's avatar
Guolin Ke committed
285
286
}

Guolin Ke's avatar
Guolin Ke committed
287
bool SerialTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
288
  Common::FunctionTimer fun_timer("SerialTreeLearner::BeforeFindBestSplit", global_timer);
Guolin Ke's avatar
Guolin Ke committed
289
  // check depth of current leaf
Guolin Ke's avatar
Guolin Ke committed
290
  if (config_->max_depth > 0) {
Guolin Ke's avatar
Guolin Ke committed
291
    // only need to check left leaf, since right leaf is in same level of left leaf
Guolin Ke's avatar
Guolin Ke committed
292
    if (tree->leaf_depth(left_leaf) >= config_->max_depth) {
Guolin Ke's avatar
Guolin Ke committed
293
294
295
296
297
298
299
      best_split_per_leaf_[left_leaf].gain = kMinScore;
      if (right_leaf >= 0) {
        best_split_per_leaf_[right_leaf].gain = kMinScore;
      }
      return false;
    }
  }
Guolin Ke's avatar
Guolin Ke committed
300
301
302
  data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
  data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
  // no enough data to continue
Guolin Ke's avatar
Guolin Ke committed
303
304
  if (num_data_in_right_child < static_cast<data_size_t>(config_->min_data_in_leaf * 2)
      && num_data_in_left_child < static_cast<data_size_t>(config_->min_data_in_leaf * 2)) {
Guolin Ke's avatar
Guolin Ke committed
305
306
307
308
309
310
    best_split_per_leaf_[left_leaf].gain = kMinScore;
    if (right_leaf >= 0) {
      best_split_per_leaf_[right_leaf].gain = kMinScore;
    }
    return false;
  }
311
  parent_leaf_histogram_array_ = nullptr;
Guolin Ke's avatar
Guolin Ke committed
312
313
  // only have root
  if (right_leaf < 0) {
314
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
315
316
    larger_leaf_histogram_array_ = nullptr;
  } else if (num_data_in_left_child < num_data_in_right_child) {
Hui Xue's avatar
Hui Xue committed
317
    // put parent(left) leaf's histograms into larger leaf's histograms
318
319
320
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Move(left_leaf, right_leaf);
    histogram_pool_.Get(left_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
321
  } else {
Hui Xue's avatar
Hui Xue committed
322
    // put parent(left) leaf's histograms to larger leaf's histograms
323
324
    if (histogram_pool_.Get(left_leaf, &larger_leaf_histogram_array_)) { parent_leaf_histogram_array_ = larger_leaf_histogram_array_; }
    histogram_pool_.Get(right_leaf, &smaller_leaf_histogram_array_);
Guolin Ke's avatar
Guolin Ke committed
325
326
327
328
  }
  return true;
}

329
void SerialTreeLearner::FindBestSplits(const Tree* tree) {
330
331
332
333
  FindBestSplits(tree, nullptr);
}

void SerialTreeLearner::FindBestSplits(const Tree* tree, const std::set<int>* force_features) {
Guolin Ke's avatar
Guolin Ke committed
334
  std::vector<int8_t> is_feature_used(num_features_, 0);
335
  #pragma omp parallel for schedule(static, 256) if (num_features_ >= 512)
Guolin Ke's avatar
Guolin Ke committed
336
  for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
337
    if (!col_sampler_.is_feature_used_bytree()[feature_index] && (force_features == nullptr || force_features->find(feature_index) == force_features->end())) continue;
Guolin Ke's avatar
Guolin Ke committed
338
339
340
341
342
343
344
345
    if (parent_leaf_histogram_array_ != nullptr
        && !parent_leaf_histogram_array_[feature_index].is_splittable()) {
      smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
      continue;
    }
    is_feature_used[feature_index] = 1;
  }
  bool use_subtract = parent_leaf_histogram_array_ != nullptr;
346

Guolin Ke's avatar
Guolin Ke committed
347
  ConstructHistograms(is_feature_used, use_subtract);
348
  FindBestSplitsFromHistograms(is_feature_used, use_subtract, tree);
Guolin Ke's avatar
Guolin Ke committed
349
350
}

351
352
353
354
void SerialTreeLearner::ConstructHistograms(
    const std::vector<int8_t>& is_feature_used, bool use_subtract) {
  Common::FunctionTimer fun_timer("SerialTreeLearner::ConstructHistograms",
                                  global_timer);
Guolin Ke's avatar
Guolin Ke committed
355
  // construct smaller leaf
356
357
  hist_t* ptr_smaller_leaf_hist_data =
      smaller_leaf_histogram_array_[0].RawData() - kHistOffset;
358
359
360
  train_data_->ConstructHistograms(
      is_feature_used, smaller_leaf_splits_->data_indices(),
      smaller_leaf_splits_->num_data_in_leaf(), gradients_, hessians_,
361
362
      ordered_gradients_.data(), ordered_hessians_.data(), share_state_.get(),
      ptr_smaller_leaf_hist_data);
Guolin Ke's avatar
Guolin Ke committed
363
364
  if (larger_leaf_histogram_array_ != nullptr && !use_subtract) {
    // construct larger leaf
365
366
    hist_t* ptr_larger_leaf_hist_data =
        larger_leaf_histogram_array_[0].RawData() - kHistOffset;
367
368
369
    train_data_->ConstructHistograms(
        is_feature_used, larger_leaf_splits_->data_indices(),
        larger_leaf_splits_->num_data_in_leaf(), gradients_, hessians_,
370
        ordered_gradients_.data(), ordered_hessians_.data(), share_state_.get(),
371
        ptr_larger_leaf_hist_data);
Guolin Ke's avatar
Guolin Ke committed
372
  }
373
374
}

Guolin Ke's avatar
Guolin Ke committed
375
void SerialTreeLearner::FindBestSplitsFromHistograms(
376
    const std::vector<int8_t>& is_feature_used, bool use_subtract, const Tree* tree) {
Guolin Ke's avatar
Guolin Ke committed
377
378
  Common::FunctionTimer fun_timer(
      "SerialTreeLearner::FindBestSplitsFromHistograms", global_timer);
379
380
  std::vector<SplitInfo> smaller_best(share_state_->num_threads);
  std::vector<SplitInfo> larger_best(share_state_->num_threads);
381
382
  std::vector<int8_t> smaller_node_used_features = col_sampler_.GetByNode(tree, smaller_leaf_splits_->leaf_index());
  std::vector<int8_t> larger_node_used_features;
383
384
385
386
387
  double smaller_leaf_parent_output = GetParentOutput(tree, smaller_leaf_splits_.get());
  double larger_leaf_parent_output = 0;
  if (larger_leaf_splits_ != nullptr && larger_leaf_splits_->leaf_index() >= 0) {
    larger_leaf_parent_output = GetParentOutput(tree, larger_leaf_splits_.get());
  }
388
389
390
  if (larger_leaf_splits_->leaf_index() >= 0) {
    larger_node_used_features = col_sampler_.GetByNode(tree, larger_leaf_splits_->leaf_index());
  }
391
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
392
// find splits
393
#pragma omp parallel for schedule(static) num_threads(share_state_->num_threads)
Guolin Ke's avatar
Guolin Ke committed
394
  for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
395
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
396
397
398
    if (!is_feature_used[feature_index]) {
      continue;
    }
Guolin Ke's avatar
Guolin Ke committed
399
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
400
401
402
403
    train_data_->FixHistogram(
        feature_index, smaller_leaf_splits_->sum_gradients(),
        smaller_leaf_splits_->sum_hessians(),
        smaller_leaf_histogram_array_[feature_index].RawData());
404
    int real_fidx = train_data_->RealFeatureIndex(feature_index);
405
406
407
408
409

    ComputeBestSplitForFeature(smaller_leaf_histogram_array_, feature_index,
                               real_fidx,
                               smaller_node_used_features[feature_index],
                               smaller_leaf_splits_->num_data_in_leaf(),
410
411
                               smaller_leaf_splits_.get(), &smaller_best[tid],
                               smaller_leaf_parent_output);
412

Guolin Ke's avatar
Guolin Ke committed
413
    // only has root leaf
Guolin Ke's avatar
Guolin Ke committed
414
415
416
417
    if (larger_leaf_splits_ == nullptr ||
        larger_leaf_splits_->leaf_index() < 0) {
      continue;
    }
Guolin Ke's avatar
Guolin Ke committed
418

Guolin Ke's avatar
Guolin Ke committed
419
    if (use_subtract) {
Guolin Ke's avatar
Guolin Ke committed
420
421
      larger_leaf_histogram_array_[feature_index].Subtract(
          smaller_leaf_histogram_array_[feature_index]);
422
    } else {
Guolin Ke's avatar
Guolin Ke committed
423
424
425
426
      train_data_->FixHistogram(
          feature_index, larger_leaf_splits_->sum_gradients(),
          larger_leaf_splits_->sum_hessians(),
          larger_leaf_histogram_array_[feature_index].RawData());
427
    }
428
429
430
431
432

    ComputeBestSplitForFeature(larger_leaf_histogram_array_, feature_index,
                               real_fidx,
                               larger_node_used_features[feature_index],
                               larger_leaf_splits_->num_data_in_leaf(),
433
434
                               larger_leaf_splits_.get(), &larger_best[tid],
                               larger_leaf_parent_output);
435

436
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
437
  }
438
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
439
  auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_best);
440
  int leaf = smaller_leaf_splits_->leaf_index();
Guolin Ke's avatar
Guolin Ke committed
441
442
  best_split_per_leaf_[leaf] = smaller_best[smaller_best_idx];

Guolin Ke's avatar
Guolin Ke committed
443
444
  if (larger_leaf_splits_ != nullptr &&
      larger_leaf_splits_->leaf_index() >= 0) {
445
    leaf = larger_leaf_splits_->leaf_index();
Guolin Ke's avatar
Guolin Ke committed
446
447
448
449
450
    auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_best);
    best_split_per_leaf_[leaf] = larger_best[larger_best_idx];
  }
}

451
452
453
454
455
456
/*!
 * \brief Apply the user-supplied forced splits (forced_split_json_) in BFS
 *        order before the greedy loop starts.
 *
 * For every queued node, the gain of the requested (feature, threshold) split
 * is gathered from the leaf's histogram; splits with negative gain are
 * dropped, which aborts the remaining forced splits. On abort, one regular
 * best-gain split is attempted instead.
 *
 * \return the number of splits performed (counts against the leaf budget);
 *         config_->num_leaves when the fallback search also finds no
 *         positive-gain split
 */
int32_t SerialTreeLearner::ForceSplits(Tree* tree, int* left_leaf,
                                       int* right_leaf, int *cur_depth) {
  bool abort_last_forced_split = false;
  if (forced_split_json_ == nullptr) {
    // no forced-split configuration supplied
    return 0;
  }
  int32_t result_count = 0;
  // start at root leaf
  *left_leaf = 0;
  std::queue<std::pair<Json, int>> q;
  Json left = *forced_split_json_;
  Json right;
  bool left_smaller = true;
  std::unordered_map<int, SplitInfo> forceSplitMap;
  q.emplace(left, *left_leaf);
  // Histogram construction require parent features.
  std::set<int> force_split_features = FindAllForceFeatures(*forced_split_json_);
  while (!q.empty()) {
    if (BeforeFindBestSplit(tree, *left_leaf, *right_leaf)) {
      FindBestSplits(tree, &force_split_features);
    }

    // then, compute own splits
    SplitInfo left_split;
    SplitInfo right_split;

    if (!left.is_null()) {
      const int left_feature = left["feature"].int_value();
      const double left_threshold_double = left["threshold"].number_value();
      const int left_inner_feature_index = train_data_->InnerFeatureIndex(left_feature);
      const uint32_t left_threshold = train_data_->BinThreshold(
              left_inner_feature_index, left_threshold_double);
      // the child produced by the previous split may be tracked as either the
      // smaller or the larger leaf, depending on its data count
      auto leaf_histogram_array = (left_smaller) ? smaller_leaf_histogram_array_ : larger_leaf_histogram_array_;
      auto left_leaf_splits = (left_smaller) ? smaller_leaf_splits_.get() : larger_leaf_splits_.get();
      leaf_histogram_array[left_inner_feature_index].GatherInfoForThreshold(
              left_leaf_splits->sum_gradients(),
              left_leaf_splits->sum_hessians(),
              left_threshold,
              left_leaf_splits->num_data_in_leaf(),
              left_leaf_splits->weight(),
              &left_split);
      left_split.feature = left_feature;
      forceSplitMap[*left_leaf] = left_split;
      if (left_split.gain < 0) {
        // negative gain: drop this forced split
        forceSplitMap.erase(*left_leaf);
      }
    }

    if (!right.is_null()) {
      const int right_feature = right["feature"].int_value();
      const double right_threshold_double = right["threshold"].number_value();
      const int right_inner_feature_index = train_data_->InnerFeatureIndex(right_feature);
      const uint32_t right_threshold = train_data_->BinThreshold(
              right_inner_feature_index, right_threshold_double);
      auto leaf_histogram_array = (left_smaller) ? larger_leaf_histogram_array_ : smaller_leaf_histogram_array_;
      auto right_leaf_splits = (left_smaller) ? larger_leaf_splits_.get() : smaller_leaf_splits_.get();
      leaf_histogram_array[right_inner_feature_index].GatherInfoForThreshold(
        right_leaf_splits->sum_gradients(),
        right_leaf_splits->sum_hessians(),
        right_threshold,
        right_leaf_splits->num_data_in_leaf(),
        right_leaf_splits->weight(),
        &right_split);
      right_split.feature = right_feature;
      forceSplitMap[*right_leaf] = right_split;
      if (right_split.gain < 0) {
        // negative gain: drop this forced split
        forceSplitMap.erase(*right_leaf);
      }
    }

    std::pair<Json, int> pair = q.front();
    q.pop();
    const int current_leaf = pair.second;
    // split info should exist because searching in bfs fashion - should have added from parent
    if (forceSplitMap.find(current_leaf) == forceSplitMap.end()) {
        abort_last_forced_split = true;
        break;
    }
    best_split_per_leaf_[current_leaf] = forceSplitMap[current_leaf];
    Split(tree, current_leaf, left_leaf, right_leaf);
    left_smaller = best_split_per_leaf_[current_leaf].left_count <
                   best_split_per_leaf_[current_leaf].right_count;
    left = Json();
    right = Json();
    // enqueue child descriptors that themselves request a further split
    if ((pair.first).object_items().count("left") > 0) {
      left = (pair.first)["left"];
      if (left.object_items().count("feature") > 0 && left.object_items().count("threshold") > 0) {
        q.emplace(left, *left_leaf);
      }
    }
    if ((pair.first).object_items().count("right") > 0) {
      right = (pair.first)["right"];
      if (right.object_items().count("feature") > 0 && right.object_items().count("threshold") > 0) {
        q.emplace(right, *right_leaf);
      }
    }
    result_count++;
    *(cur_depth) = std::max(*(cur_depth), tree->leaf_depth(*left_leaf));
  }
  if (abort_last_forced_split) {
    // fall back to one regular best-gain split for the aborted node
    const int best_leaf =
        static_cast<int>(ArrayArgs<SplitInfo>::ArgMax(best_split_per_leaf_));
    const SplitInfo& best_leaf_SplitInfo = best_split_per_leaf_[best_leaf];
    if (best_leaf_SplitInfo.gain <= 0.0) {
      Log::Warning("No further splits with positive gain, best gain: %f",
                   best_leaf_SplitInfo.gain);
      return config_->num_leaves;
    }
    Split(tree, best_leaf, left_leaf, right_leaf);
    *(cur_depth) = std::max(*(cur_depth), tree->leaf_depth(*left_leaf));
    ++result_count;
  }
  return result_count;
}
Guolin Ke's avatar
Guolin Ke committed
566

567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
/*!
 * \brief Collect the inner indices of every feature referenced anywhere in
 *        the forced-split JSON tree (BFS over "left"/"right" children).
 */
std::set<int> SerialTreeLearner::FindAllForceFeatures(Json force_split_leaf_setting) {
  std::set<int> result;
  std::queue<Json> pending;
  pending.push(force_split_leaf_setting);
  while (!pending.empty()) {
    const Json node = pending.front();
    pending.pop();
    // map the user-facing feature index to the inner (used-feature) index
    result.insert(train_data_->InnerFeatureIndex(node["feature"].int_value()));
    if (node.object_items().count("left") > 0) {
      pending.push(node["left"]);
    }
    if (node.object_items().count("right") > 0) {
      pending.push(node["right"]);
    }
  }
  return result;
}

593
594
595
void SerialTreeLearner::SplitInner(Tree* tree, int best_leaf, int* left_leaf,
                                   int* right_leaf, bool update_cnt) {
  // Performs the actual split of `best_leaf` using the best split found for
  // it: partitions the data, grows the tree, re-initializes the smaller /
  // larger leaf histogram wrappers, and updates monotone constraints.
  // On return, *left_leaf is the original leaf id and *right_leaf the new one.
  Common::FunctionTimer fun_timer("SerialTreeLearner::SplitInner", global_timer);
  SplitInfo& best_split_info = best_split_per_leaf_[best_leaf];
  const int inner_feature_index =
      train_data_->InnerFeatureIndex(best_split_info.feature);
  if (cegb_ != nullptr) {
    // cost-effective gradient boosting: account for the feature's cost
    cegb_->UpdateLeafBestSplits(tree, best_leaf, &best_split_info,
                                &best_split_per_leaf_);
  }
  *left_leaf = best_leaf;
  auto next_leaf_id = tree->NextLeafId();

  // update before tree split
  constraints_->BeforeSplit(best_leaf, next_leaf_id,
                            best_split_info.monotone_type);

  bool is_numerical_split =
      train_data_->FeatureBinMapper(inner_feature_index)->bin_type() ==
      BinType::NumericalBin;
  if (is_numerical_split) {
    auto threshold_double = train_data_->RealThreshold(
        inner_feature_index, best_split_info.threshold);
    data_partition_->Split(best_leaf, train_data_, inner_feature_index,
                           &best_split_info.threshold, 1,
                           best_split_info.default_left, next_leaf_id);
    if (update_cnt) {
      // don't need to update this in data-based parallel model
      best_split_info.left_count = data_partition_->leaf_count(*left_leaf);
      best_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
    }
    // split tree, will return right leaf
    *right_leaf = tree->Split(
        best_leaf, inner_feature_index, best_split_info.feature,
        best_split_info.threshold, threshold_double,
        static_cast<double>(best_split_info.left_output),
        static_cast<double>(best_split_info.right_output),
        static_cast<data_size_t>(best_split_info.left_count),
        static_cast<data_size_t>(best_split_info.right_count),
        static_cast<double>(best_split_info.left_sum_hessian),
        static_cast<double>(best_split_info.right_sum_hessian),
        // store the true split gain in tree model
        static_cast<float>(best_split_info.gain + config_->min_gain_to_split),
        train_data_->FeatureBinMapper(inner_feature_index)->missing_type(),
        best_split_info.default_left);
  } else {
    // categorical split: the threshold is a set of bins, stored as a bitset
    std::vector<uint32_t> cat_bitset_inner =
        Common::ConstructBitset(best_split_info.cat_threshold.data(),
                                best_split_info.num_cat_threshold);
    std::vector<int> threshold_int(best_split_info.num_cat_threshold);
    for (int i = 0; i < best_split_info.num_cat_threshold; ++i) {
      threshold_int[i] = static_cast<int>(train_data_->RealThreshold(
          inner_feature_index, best_split_info.cat_threshold[i]));
    }
    // bitset over real (un-binned) category values, kept in the tree model
    std::vector<uint32_t> cat_bitset = Common::ConstructBitset(
        threshold_int.data(), best_split_info.num_cat_threshold);

    data_partition_->Split(best_leaf, train_data_, inner_feature_index,
                           cat_bitset_inner.data(),
                           static_cast<int>(cat_bitset_inner.size()),
                           best_split_info.default_left, next_leaf_id);

    if (update_cnt) {
      // don't need to update this in data-based parallel model
      best_split_info.left_count = data_partition_->leaf_count(*left_leaf);
      best_split_info.right_count = data_partition_->leaf_count(next_leaf_id);
    }

    *right_leaf = tree->SplitCategorical(
        best_leaf, inner_feature_index, best_split_info.feature,
        cat_bitset_inner.data(), static_cast<int>(cat_bitset_inner.size()),
        cat_bitset.data(), static_cast<int>(cat_bitset.size()),
        static_cast<double>(best_split_info.left_output),
        static_cast<double>(best_split_info.right_output),
        static_cast<data_size_t>(best_split_info.left_count),
        static_cast<data_size_t>(best_split_info.right_count),
        static_cast<double>(best_split_info.left_sum_hessian),
        static_cast<double>(best_split_info.right_sum_hessian),
        // store the true split gain in tree model
        static_cast<float>(best_split_info.gain + config_->min_gain_to_split),
        train_data_->FeatureBinMapper(inner_feature_index)->missing_type());
  }

#ifdef DEBUG
  CHECK(*right_leaf == next_leaf_id);
#endif

  // init the leaves that used on next iteration
  // (the smaller leaf gets its histogram built; the larger is derived by
  // subtraction, so assign them by data count)
  if (best_split_info.left_count < best_split_info.right_count) {
    CHECK_GT(best_split_info.left_count, 0);
    smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                               best_split_info.left_sum_gradient,
                               best_split_info.left_sum_hessian,
                               best_split_info.left_output);
    larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                              best_split_info.right_sum_gradient,
                              best_split_info.right_sum_hessian,
                              best_split_info.right_output);
  } else {
    CHECK_GT(best_split_info.right_count, 0);
    smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(),
                               best_split_info.right_sum_gradient,
                               best_split_info.right_sum_hessian,
                               best_split_info.right_output);
    larger_leaf_splits_->Init(*left_leaf, data_partition_.get(),
                              best_split_info.left_sum_gradient,
                              best_split_info.left_sum_hessian,
                              best_split_info.left_output);
  }
  auto leaves_need_update = constraints_->Update(
      is_numerical_split, *left_leaf, *right_leaf,
      best_split_info.monotone_type, best_split_info.right_output,
      best_split_info.left_output, inner_feature_index, best_split_info,
      best_split_per_leaf_);
  // update leave outputs if needed
  for (auto leaf : leaves_need_update) {
    RecomputeBestSplitForLeaf(tree, leaf, &best_split_per_leaf_[leaf]);
  }
}
Guolin Ke's avatar
Guolin Ke committed
712

713
void SerialTreeLearner::RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, std::function<double(const label_t*, int)> residual_getter,
                                        data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt, const double* /*train_score*/) const {
  // Lets the objective function renew each leaf's output (e.g. for objectives
  // whose optimal leaf value is not the gradient/hessian ratio). In the
  // distributed case, leaf outputs are averaged over the machines that own
  // data for that leaf.
  if (obj != nullptr && obj->IsRenewTreeOutput()) {
    CHECK_LE(tree->num_leaves(), data_partition_->num_leaves());
    // when bagging is used, map local row indices back to the full dataset
    const data_size_t* bag_mapper = nullptr;
    if (total_num_data != num_data_) {
      CHECK_EQ(bag_cnt, num_data_);
      bag_mapper = bag_indices;
    }
    // per-leaf flag: 1 if this worker has data for the leaf, else 0
    std::vector<int> n_nozeroworker_perleaf(tree->num_leaves(), 1);
    int num_machines = Network::num_machines();
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < tree->num_leaves(); ++i) {
      const double output = static_cast<double>(tree->LeafOutput(i));
      data_size_t cnt_leaf_data = 0;
      auto index_mapper = data_partition_->GetIndexOnLeaf(i, &cnt_leaf_data);
      if (cnt_leaf_data > 0) {
        // bag_mapper[index_mapper[i]]
        const double new_output = obj->RenewTreeOutput(output, residual_getter, index_mapper, bag_mapper, cnt_leaf_data);
        tree->SetLeafOutput(i, new_output);
      } else {
        // only possible in distributed training: another machine owns the data
        CHECK_GT(num_machines, 1);
        tree->SetLeafOutput(i, 0.0);
        n_nozeroworker_perleaf[i] = 0;
      }
    }
    if (num_machines > 1) {
      // sum outputs and worker counts across machines, then average
      std::vector<double> outputs(tree->num_leaves());
      for (int i = 0; i < tree->num_leaves(); ++i) {
        outputs[i] = static_cast<double>(tree->LeafOutput(i));
      }
      outputs = Network::GlobalSum(&outputs);
      n_nozeroworker_perleaf = Network::GlobalSum(&n_nozeroworker_perleaf);
      for (int i = 0; i < tree->num_leaves(); ++i) {
        tree->SetLeafOutput(i, outputs[i] / n_nozeroworker_perleaf[i]);
      }
    }
  }
}

753
754
void SerialTreeLearner::ComputeBestSplitForFeature(
    FeatureHistogram* histogram_array_, int feature_index, int real_fidx,
Guolin Ke's avatar
Guolin Ke committed
755
    int8_t is_feature_used, int num_data, const LeafSplits* leaf_splits,
756
    SplitInfo* best_split, double parent_output) {
757
758
759
760
761
762
763
  bool is_feature_numerical = train_data_->FeatureBinMapper(feature_index)
                                  ->bin_type() == BinType::NumericalBin;
  if (is_feature_numerical & !config_->monotone_constraints.empty()) {
    constraints_->RecomputeConstraintsIfNeeded(
        constraints_.get(), feature_index, ~(leaf_splits->leaf_index()),
        train_data_->FeatureNumBin(feature_index));
  }
764
765
766
  SplitInfo new_split;
  histogram_array_[feature_index].FindBestThreshold(
      leaf_splits->sum_gradients(), leaf_splits->sum_hessians(), num_data,
767
      constraints_->GetFeatureConstraint(leaf_splits->leaf_index(), feature_index), parent_output, &new_split);
768
769
770
  new_split.feature = real_fidx;
  if (cegb_ != nullptr) {
    new_split.gain -=
771
        cegb_->DeltaGain(feature_index, real_fidx, leaf_splits->leaf_index(),
772
773
                         num_data, new_split);
  }
774
775
776
777
778
  if (new_split.monotone_type != 0) {
    double penalty = constraints_->ComputeMonotoneSplitGainPenalty(
        leaf_splits->leaf_index(), config_->monotone_penalty);
    new_split.gain *= penalty;
  }
Guolin Ke's avatar
Guolin Ke committed
779
780
781
  // it is needed to filter the features after the above code.
  // Otherwise, the `is_splittable` in `FeatureHistogram` will be wrong, and cause some features being accidentally filtered in the later nodes.
  if (new_split > *best_split && is_feature_used) {
782
783
784
785
    *best_split = new_split;
  }
}

786
787
788
789
790
791
792
793
794
795
796
797
798
799
double SerialTreeLearner::GetParentOutput(const Tree* tree, const LeafSplits* leaf_splits) const {
  // Returns the parent output used for path smoothing of this leaf's split.
  if (tree->num_leaves() > 1) {
    // non-root leaf: use the stored weight of the leaf being split
    return leaf_splits->weight();
  }
  // for root leaf the "parent" output is its own output because we don't
  // apply any smoothing to the root
  return FeatureHistogram::CalculateSplittedLeafOutput<true, true, true, false>(
      leaf_splits->sum_gradients(), leaf_splits->sum_hessians(), config_->lambda_l1,
      config_->lambda_l2, config_->max_delta_step, BasicConstraint(),
      config_->path_smooth, static_cast<data_size_t>(leaf_splits->num_data_in_leaf()), 0);
}

800
void SerialTreeLearner::RecomputeBestSplitForLeaf(Tree* tree, int leaf, SplitInfo* split) {
  // Re-evaluates the best split of an existing leaf from its cached
  // histograms (used after monotone-constraint updates invalidate the
  // previously stored best split). Writes the new best split into *split.
  FeatureHistogram* histogram_array_;
  if (!histogram_pool_.Get(leaf, &histogram_array_)) {
    // histogram was evicted from the pool; keep the stale split rather than crash
    Log::Warning(
        "Get historical Histogram for leaf %d failed, will skip the "
        "``RecomputeBestSplitForLeaf``",
        leaf);
    return;
  }
  // reconstruct the leaf's aggregates from the split's two halves
  double sum_gradients = split->left_sum_gradient + split->right_sum_gradient;
  double sum_hessians = split->left_sum_hessian + split->right_sum_hessian;
  int num_data = split->left_count + split->right_count;

  std::vector<SplitInfo> bests(share_state_->num_threads);
  LeafSplits leaf_splits(num_data, config_);
  leaf_splits.Init(leaf, sum_gradients, sum_hessians);

  // can't use GetParentOutput because leaf_splits doesn't have weight property set
  double parent_output = 0;
  if (config_->path_smooth > kEpsilon) {
    parent_output = FeatureHistogram::CalculateSplittedLeafOutput<true, true, true, false>(
      sum_gradients, sum_hessians, config_->lambda_l1, config_->lambda_l2, config_->max_delta_step,
      BasicConstraint(), config_->path_smooth, static_cast<data_size_t>(num_data), 0);
  }

  OMP_INIT_EX();
  // find splits; each thread keeps its own best and they are reduced below
  std::vector<int8_t> node_used_features = col_sampler_.GetByNode(tree, leaf);
#pragma omp parallel for schedule(static) num_threads(share_state_->num_threads)
  for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
    OMP_LOOP_EX_BEGIN();
    if (!col_sampler_.is_feature_used_bytree()[feature_index] ||
        !histogram_array_[feature_index].is_splittable()) {
      continue;
    }
    const int tid = omp_get_thread_num();
    int real_fidx = train_data_->RealFeatureIndex(feature_index);
    ComputeBestSplitForFeature(histogram_array_, feature_index, real_fidx, node_used_features[feature_index],
                               num_data, &leaf_splits, &bests[tid], parent_output);

    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();
  // reduce the per-thread bests into the final answer
  auto best_idx = ArrayArgs<SplitInfo>::ArgMax(bests);
  *split = bests[best_idx];
}

Guolin Ke's avatar
Guolin Ke committed
847
}  // namespace LightGBM