voting_parallel_tree_learner.cpp 21.7 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
#include <LightGBM/utils/common.h>
Guolin Ke's avatar
Guolin Ke committed
6
7
8
9
10

#include <algorithm>
#include <cstring>
#include <functional>
#include <tuple>
#include <vector>

11
12
#include "parallel_tree_learner.h"

Guolin Ke's avatar
Guolin Ke committed
13
14
namespace LightGBM {

15
template <typename TREELEARNER_T>
Guolin Ke's avatar
Guolin Ke committed
16
17
18
VotingParallelTreeLearner<TREELEARNER_T>::VotingParallelTreeLearner(const Config* config)
  :TREELEARNER_T(config) {
  top_k_ = this->config_->top_k;
Guolin Ke's avatar
Guolin Ke committed
19
20
}

21
22
23
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, bool is_constant_hessian) {
  TREELEARNER_T::Init(train_data, is_constant_hessian);
  rank_ = Network::rank();
  num_machines_ = Network::num_machines();

  // A machine cannot vote for more features than the dataset has.
  top_k_ = std::min(top_k_, this->num_features_);

  // Largest per-feature bin count; bounds the byte size of one feature histogram.
  int max_bin = 0;
  for (int feature = 0; feature < this->num_features_; ++feature) {
    max_bin = std::max(max_bin, this->train_data_->FeatureNumBin(feature));
  }

  // The two communication buffers are reused for:
  //  - up to 2 * top_k_ local feature histograms (CopyLocalHistogram / ReduceScatter),
  //  - 2 * top_k_ LightSplitInfo from every machine (Allgather in FindBestSplits),
  //  - the smaller and larger best SplitInfo (SyncUpGlobalBestSplit).
  // Size them for the largest of these.
  const size_t vote_or_histogram_bytes =
      2 * top_k_ * std::max(max_bin * kHistEntrySize, sizeof(LightSplitInfo) * num_machines_);
  const size_t best_split_bytes =
      static_cast<size_t>(SplitInfo::Size(this->config_->max_cat_threshold) * 2);
  const size_t buffer_size = std::max(vote_or_histogram_bytes, best_split_bytes);
  input_buffer_.resize(buffer_size);
  output_buffer_.resize(buffer_size);

  // Per-feature bookkeeping: which histograms this machine reduces, and where
  // each one starts inside this machine's reduce-scatter block.
  smaller_is_feature_aggregated_.resize(this->num_features_);
  larger_is_feature_aggregated_.resize(this->num_features_);

  // Per-machine block layout of the reduce-scatter buffer.
  block_start_.resize(num_machines_);
  block_len_.resize(num_machines_);

  smaller_buffer_read_start_pos_.resize(this->num_features_);
  larger_buffer_read_start_pos_.resize(this->num_features_);
  global_data_count_in_leaf_.resize(this->config_->num_leaves);

  // Leaf sums aggregated over all machines.
  smaller_leaf_splits_global_.reset(new LeafSplits(train_data->num_data(), this->config_));
  larger_leaf_splits_global_.reset(new LeafSplits(train_data->num_data(), this->config_));

  // The local histogram pool works with a copy of the config whose
  // min_data_in_leaf / min_sum_hessian_in_leaf are divided by the machine count.
  local_config_ = *this->config_;
  local_config_.min_data_in_leaf /= num_machines_;
  local_config_.min_sum_hessian_in_leaf /= num_machines_;
  this->histogram_pool_.ResetConfig(train_data, &local_config_);

  // Histograms that receive the globally reduced data for voted features.
  smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[this->num_features_]);
  larger_leaf_histogram_array_global_.reset(new FeatureHistogram[this->num_features_]);
  const auto& offsets = this->share_state_->feature_hist_offsets();
  const int num_total_bin = this->share_state_->num_hist_total_bin();
  smaller_leaf_histogram_data_.resize(num_total_bin * 2);
  larger_leaf_histogram_data_.resize(num_total_bin * 2);
  HistogramPool::SetFeatureInfo<true, true>(train_data, this->config_, &feature_metas_);
  for (int j = 0; j < train_data->num_features(); ++j) {
    // Offsets count bins; the data arrays hold two entries per bin (see the * 2 sizing above).
    hist_t* smaller_slot = smaller_leaf_histogram_data_.data() + offsets[j] * 2;
    hist_t* larger_slot = larger_leaf_histogram_data_.data() + offsets[j] * 2;
    smaller_leaf_histogram_array_global_[j].Init(smaller_slot, &feature_metas_[j]);
    larger_leaf_histogram_array_global_[j].Init(larger_slot, &feature_metas_[j]);
  }
}

81
template <typename TREELEARNER_T>
Guolin Ke's avatar
Guolin Ke committed
82
83
void VotingParallelTreeLearner<TREELEARNER_T>::ResetConfig(const Config* config) {
  TREELEARNER_T::ResetConfig(config);
Guolin Ke's avatar
Guolin Ke committed
84

Guolin Ke's avatar
Guolin Ke committed
85
86
87
  local_config_ = *this->config_;
  local_config_.min_data_in_leaf /= num_machines_;
  local_config_.min_sum_hessian_in_leaf /= num_machines_;
Guolin Ke's avatar
Guolin Ke committed
88

89
  this->histogram_pool_.ResetConfig(this->train_data_, &local_config_);
Guolin Ke's avatar
Guolin Ke committed
90
  global_data_count_in_leaf_.resize(this->config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
91

92
  HistogramPool::SetFeatureInfo<false, true>(this->train_data_, config, &feature_metas_);
Guolin Ke's avatar
Guolin Ke committed
93
}
Guolin Ke's avatar
Guolin Ke committed
94

95
96
97
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
  TREELEARNER_T::BeforeTrain();
  // Root-leaf statistics of this machine, summed over all machines below to
  // obtain the global root statistics. This is a plain aggregate rather than
  // std::tuple: std::tuple is not trivially copyable, so moving it through the
  // raw char buffers with std::memcpy is undefined behavior. This struct is
  // trivially copyable.
  struct RootSumup {
    data_size_t num_data;
    double sum_gradients;
    double sum_hessians;
  };
  RootSumup data;
  data.num_data = this->smaller_leaf_splits_->num_data_in_leaf();
  data.sum_gradients = this->smaller_leaf_splits_->sum_gradients();
  data.sum_hessians = this->smaller_leaf_splits_->sum_hessians();

  const int size = static_cast<int>(sizeof(RootSumup));
  std::memcpy(input_buffer_.data(), &data, sizeof(RootSumup));

  // Element-wise sum reducer. The buffers hold raw bytes, not RootSumup
  // objects, so copy each element in and out instead of reinterpret_cast-ing.
  Network::Allreduce(input_buffer_.data(), size, size, output_buffer_.data(),
                     [](const char* src, char* dst, int type_size, comm_size_t len) {
    comm_size_t used_size = 0;
    while (used_size < len) {
      RootSumup incoming;
      RootSumup accumulated;
      std::memcpy(&incoming, src, sizeof(RootSumup));
      std::memcpy(&accumulated, dst, sizeof(RootSumup));
      accumulated.num_data += incoming.num_data;
      accumulated.sum_gradients += incoming.sum_gradients;
      accumulated.sum_hessians += incoming.sum_hessians;
      std::memcpy(dst, &accumulated, sizeof(RootSumup));
      src += type_size;
      dst += type_size;
      used_size += type_size;
    }
  });

  std::memcpy(&data, output_buffer_.data(), sizeof(RootSumup));

  // set global sumup info
  smaller_leaf_splits_global_->Init(data.sum_gradients, data.sum_hessians);
  larger_leaf_splits_global_->Init();
  // init global data count in leaf
  global_data_count_in_leaf_[0] = data.num_data;
}

128
129
130
template <typename TREELEARNER_T>
bool VotingParallelTreeLearner<TREELEARNER_T>::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
  // Defer to the wrapped learner first; if it declines, so do we.
  if (!TREELEARNER_T::BeforeFindBestSplit(tree, left_leaf, right_leaf)) {
    return false;
  }
  const data_size_t left_global_count = GetGlobalDataCountInLeaf(left_leaf);
  const data_size_t right_global_count = GetGlobalDataCountInLeaf(right_leaf);
  // No right sibling: keep the leaf sums as the base learner left them.
  if (right_leaf < 0) {
    return true;
  }
  // Smaller/larger is decided on GLOBAL counts so every machine agrees.
  // Ties make the right leaf the smaller one (same rule as Split()).
  const bool left_is_smaller = left_global_count < right_global_count;
  const int smaller_leaf = left_is_smaller ? left_leaf : right_leaf;
  const int larger_leaf = left_is_smaller ? right_leaf : left_leaf;
  // Recompute the LOCAL sums for both children on this machine's data.
  this->smaller_leaf_splits_->Init(smaller_leaf, this->data_partition_.get(), this->gradients_, this->hessians_);
  this->larger_leaf_splits_->Init(larger_leaf, this->data_partition_.get(), this->gradients_, this->hessians_);
  return true;
}

150
template <typename TREELEARNER_T>
151
void VotingParallelTreeLearner<TREELEARNER_T>::GlobalVoting(int leaf_idx, const std::vector<LightSplitInfo>& splits, std::vector<int>* out) {
Guolin Ke's avatar
Guolin Ke committed
152
153
154
155
156
  out->clear();
  if (leaf_idx < 0) {
    return;
  }
  // get mean number on machines
157
  score_t mean_num_data = GetGlobalDataCountInLeaf(leaf_idx) / static_cast<score_t>(num_machines_);
158
  std::vector<LightSplitInfo> feature_best_split(this->train_data_->num_total_features() , LightSplitInfo());
Guolin Ke's avatar
Guolin Ke committed
159
160
161
162
163
164
165
166
167
168
169
170
171
  for (auto & split : splits) {
    int fid = split.feature;
    if (fid < 0) {
      continue;
    }
    // weighted gain
    double gain = split.gain * (split.left_count + split.right_count) / mean_num_data;
    if (gain > feature_best_split[fid].gain) {
      feature_best_split[fid] = split;
      feature_best_split[fid].gain = gain;
    }
  }
  // get top k
172
173
  std::vector<LightSplitInfo> top_k_splits;
  ArrayArgs<LightSplitInfo>::MaxK(feature_best_split, top_k_, &top_k_splits);
174
  std::stable_sort(top_k_splits.begin(), top_k_splits.end(), std::greater<LightSplitInfo>());
Guolin Ke's avatar
Guolin Ke committed
175
176
177
178
179
180
181
182
  for (auto& split : top_k_splits) {
    if (split.gain == kMinScore || split.feature == -1) {
      continue;
    }
    out->push_back(split.feature);
  }
}

183
184
185
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::CopyLocalHistogram(const std::vector<int>& smaller_top_features, const std::vector<int>& larger_top_features) {
  // Clear the "reduced on this machine" marks left over from the previous round.
  for (int f = 0; f < this->num_features_; ++f) {
    smaller_is_feature_aggregated_[f] = false;
    larger_is_feature_aggregated_[f] = false;
  }
  // Voted features are dealt out to machines in consecutive chunks of at most
  // ceil(total / num_machines_) features. Machine m reduces chunk m.
  const size_t total_num_features = smaller_top_features.size() + larger_top_features.size();
  const size_t features_per_machine = (total_num_features + num_machines_ - 1) / num_machines_;
  size_t assigned_features = 0;
  size_t next_smaller = 0;
  size_t next_larger = 0;
  block_start_[0] = 0;
  reduce_scatter_size_ = 0;
  for (int machine = 0; machine < num_machines_; ++machine) {
    const bool is_own_block = (machine == rank_);
    size_t block_bytes = 0;
    size_t block_features = 0;
    const size_t block_quota = std::min(features_per_machine, total_num_features - assigned_features);
    // Take one smaller-leaf feature, then one larger-leaf feature, until the quota is met.
    while (block_features < block_quota) {
      if (next_smaller < smaller_top_features.size()) {
        const int inner_feature_index = this->train_data_->InnerFeatureIndex(smaller_top_features[next_smaller]);
        ++block_features;
        auto& hist = this->smaller_leaf_histogram_array_[inner_feature_index];
        if (is_own_block) {
          // Remember where this feature lands inside our own reduced block.
          smaller_is_feature_aggregated_[inner_feature_index] = true;
          smaller_buffer_read_start_pos_[inner_feature_index] = static_cast<int>(block_bytes);
        }
        const auto hist_bytes = hist.SizeOfHistgram();
        std::memcpy(input_buffer_.data() + reduce_scatter_size_, hist.RawData(), hist_bytes);
        block_bytes += hist_bytes;
        reduce_scatter_size_ += hist_bytes;
        ++next_smaller;
      }
      if (block_features >= block_quota) {
        break;
      }
      if (next_larger < larger_top_features.size()) {
        const int inner_feature_index = this->train_data_->InnerFeatureIndex(larger_top_features[next_larger]);
        ++block_features;
        auto& hist = this->larger_leaf_histogram_array_[inner_feature_index];
        if (is_own_block) {
          larger_is_feature_aggregated_[inner_feature_index] = true;
          larger_buffer_read_start_pos_[inner_feature_index] = static_cast<int>(block_bytes);
        }
        const auto hist_bytes = hist.SizeOfHistgram();
        std::memcpy(input_buffer_.data() + reduce_scatter_size_, hist.RawData(), hist_bytes);
        block_bytes += hist_bytes;
        reduce_scatter_size_ += hist_bytes;
        ++next_larger;
      }
    }
    assigned_features += block_features;
    block_len_[machine] = static_cast<int>(block_bytes);
    if (machine < num_machines_ - 1) {
      block_start_[machine + 1] = block_start_[machine] + block_len_[machine];
    }
  }
}

242
template <typename TREELEARNER_T>
243
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree) {
Guolin Ke's avatar
Guolin Ke committed
244
  // use local data to find local best splits
245
  std::vector<int8_t> is_feature_used(this->num_features_, 0);
Guolin Ke's avatar
Guolin Ke committed
246
#pragma omp parallel for schedule(static)
247
  for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
248
    if (!this->col_sampler_.is_feature_used_bytree()[feature_index]) continue;
249
250
251
    if (this->parent_leaf_histogram_array_ != nullptr
      && !this->parent_leaf_histogram_array_[feature_index].is_splittable()) {
      this->smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
Guolin Ke's avatar
Guolin Ke committed
252
253
254
255
256
      continue;
    }
    is_feature_used[feature_index] = 1;
  }
  bool use_subtract = true;
257
  if (this->parent_leaf_histogram_array_ == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
258
259
    use_subtract = false;
  }
Guolin Ke's avatar
Guolin Ke committed
260
  TREELEARNER_T::ConstructHistograms(is_feature_used, use_subtract);
Guolin Ke's avatar
Guolin Ke committed
261

262
263
  std::vector<SplitInfo> smaller_bestsplit_per_features(this->num_features_);
  std::vector<SplitInfo> larger_bestsplit_per_features(this->num_features_);
264
265
  double smaller_leaf_parent_output = this->GetParentOutput(tree, this->smaller_leaf_splits_.get());
  double larger_leaf_parent_output = this->GetParentOutput(tree, this->larger_leaf_splits_.get());
266
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
267
  // find splits
Guolin Ke's avatar
Guolin Ke committed
268
#pragma omp parallel for schedule(static)
269
  for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
270
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
271
    if (!is_feature_used[feature_index]) { continue; }
272
273
274
275
276
    const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
    this->train_data_->FixHistogram(feature_index,
      this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(),
      this->smaller_leaf_histogram_array_[feature_index].RawData());

277
278
279
280
    this->ComputeBestSplitForFeature(
        this->smaller_leaf_histogram_array_, feature_index, real_feature_index,
        true, this->smaller_leaf_splits_->num_data_in_leaf(),
        this->smaller_leaf_splits_.get(),
281
282
        &smaller_bestsplit_per_features[feature_index],
        smaller_leaf_parent_output);
Guolin Ke's avatar
Guolin Ke committed
283
    // only has root leaf
284
    if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->leaf_index() < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
285
286

    if (use_subtract) {
287
      this->larger_leaf_histogram_array_[feature_index].Subtract(this->smaller_leaf_histogram_array_[feature_index]);
Guolin Ke's avatar
Guolin Ke committed
288
    } else {
289
290
      this->train_data_->FixHistogram(feature_index, this->larger_leaf_splits_->sum_gradients(), this->larger_leaf_splits_->sum_hessians(),
        this->larger_leaf_histogram_array_[feature_index].RawData());
Guolin Ke's avatar
Guolin Ke committed
291
    }
292
293
294
295
    this->ComputeBestSplitForFeature(
        this->larger_leaf_histogram_array_, feature_index, real_feature_index,
        true, this->larger_leaf_splits_->num_data_in_leaf(),
        this->larger_leaf_splits_.get(),
296
297
        &larger_bestsplit_per_features[feature_index],
        larger_leaf_parent_output);
298
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
299
  }
300
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
301

Guolin Ke's avatar
Guolin Ke committed
302
303
  std::vector<SplitInfo> smaller_top_k_splits, larger_top_k_splits;
  // local voting
Guolin Ke's avatar
Guolin Ke committed
304
305
  ArrayArgs<SplitInfo>::MaxK(smaller_bestsplit_per_features, top_k_, &smaller_top_k_splits);
  ArrayArgs<SplitInfo>::MaxK(larger_bestsplit_per_features, top_k_, &larger_top_k_splits);
306
307
308
309
310
311
312
313

  std::vector<LightSplitInfo> smaller_top_k_light_splits(top_k_);
  std::vector<LightSplitInfo> larger_top_k_light_splits(top_k_);
  for (int i = 0; i < top_k_; ++i) {
    smaller_top_k_light_splits[i].CopyFrom(smaller_top_k_splits[i]);
    larger_top_k_light_splits[i].CopyFrom(larger_top_k_splits[i]);
  }

Guolin Ke's avatar
Guolin Ke committed
314
315
316
  // gather
  int offset = 0;
  for (int i = 0; i < top_k_; ++i) {
317
318
319
320
    std::memcpy(input_buffer_.data() + offset, &smaller_top_k_light_splits[i], sizeof(LightSplitInfo));
    offset += sizeof(LightSplitInfo);
    std::memcpy(input_buffer_.data() + offset, &larger_top_k_light_splits[i], sizeof(LightSplitInfo));
    offset += sizeof(LightSplitInfo);
Guolin Ke's avatar
Guolin Ke committed
321
322
323
  }
  Network::Allgather(input_buffer_.data(), offset, output_buffer_.data());
  // get all top-k from all machines
324
325
  std::vector<LightSplitInfo> smaller_top_k_splits_global;
  std::vector<LightSplitInfo> larger_top_k_splits_global;
Guolin Ke's avatar
Guolin Ke committed
326
327
328
  offset = 0;
  for (int i = 0; i < num_machines_; ++i) {
    for (int j = 0; j < top_k_; ++j) {
329
330
331
332
333
334
      smaller_top_k_splits_global.push_back(LightSplitInfo());
      std::memcpy(&smaller_top_k_splits_global.back(), output_buffer_.data() + offset, sizeof(LightSplitInfo));
      offset += sizeof(LightSplitInfo);
      larger_top_k_splits_global.push_back(LightSplitInfo());
      std::memcpy(&larger_top_k_splits_global.back(), output_buffer_.data() + offset, sizeof(LightSplitInfo));
      offset += sizeof(LightSplitInfo);
Guolin Ke's avatar
Guolin Ke committed
335
336
337
338
    }
  }
  // global voting
  std::vector<int> smaller_top_features, larger_top_features;
339
340
  GlobalVoting(this->smaller_leaf_splits_->leaf_index(), smaller_top_k_splits_global, &smaller_top_features);
  GlobalVoting(this->larger_leaf_splits_->leaf_index(), larger_top_k_splits_global, &larger_top_features);
Guolin Ke's avatar
Guolin Ke committed
341
342
343
344
  // copy local histgrams to buffer
  CopyLocalHistogram(smaller_top_features, larger_top_features);

  // Reduce scatter for histogram
345
346
  Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(), block_len_.data(),
                         output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);
Guolin Ke's avatar
Guolin Ke committed
347

348
  this->FindBestSplitsFromHistograms(is_feature_used, false, tree);
Guolin Ke's avatar
Guolin Ke committed
349
350
351
}

template <typename TREELEARNER_T>
352
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool, const Tree* tree) {
353
354
  std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
  std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
355
  std::vector<int8_t> smaller_node_used_features =
356
      this->col_sampler_.GetByNode(tree, this->smaller_leaf_splits_->leaf_index());
357
  std::vector<int8_t> larger_node_used_features =
358
      this->col_sampler_.GetByNode(tree, this->larger_leaf_splits_->leaf_index());
359
360
  double smaller_leaf_parent_output = this->GetParentOutput(tree, this->smaller_leaf_splits_global_.get());
  double larger_leaf_parent_output = this->GetParentOutput(tree, this->larger_leaf_splits_global_.get());
Guolin Ke's avatar
Guolin Ke committed
361
  // find best split from local aggregated histograms
Guolin Ke's avatar
Guolin Ke committed
362
  OMP_INIT_EX();
363
#pragma omp parallel for schedule(static) num_threads(this->share_state_->num_threads)
364
  for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
365
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
366
    const int tid = omp_get_thread_num();
Guolin Ke's avatar
Guolin Ke committed
367
    const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
Guolin Ke's avatar
Guolin Ke committed
368
369
370
    if (smaller_is_feature_aggregated_[feature_index]) {
      // restore from buffer
      smaller_leaf_histogram_array_global_[feature_index].FromMemory(
Guolin Ke's avatar
Guolin Ke committed
371
        output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]);
Guolin Ke's avatar
Guolin Ke committed
372

373
      this->train_data_->FixHistogram(feature_index,
Guolin Ke's avatar
Guolin Ke committed
374
375
                                      smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_hessians(),
                                      smaller_leaf_histogram_array_global_[feature_index].RawData());
Guolin Ke's avatar
Guolin Ke committed
376

377
378
379
380
      this->ComputeBestSplitForFeature(
          smaller_leaf_histogram_array_global_.get(), feature_index,
          real_feature_index, smaller_node_used_features[feature_index],
          GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->leaf_index()),
381
382
          smaller_leaf_splits_global_.get(), &smaller_bests_per_thread[tid],
          smaller_leaf_parent_output);
Guolin Ke's avatar
Guolin Ke committed
383
384
385
386
387
    }

    if (larger_is_feature_aggregated_[feature_index]) {
      // restore from buffer
      larger_leaf_histogram_array_global_[feature_index].FromMemory(output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]);
Guolin Ke's avatar
Guolin Ke committed
388

389
      this->train_data_->FixHistogram(feature_index,
Guolin Ke's avatar
Guolin Ke committed
390
391
                                      larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_hessians(),
                                      larger_leaf_histogram_array_global_[feature_index].RawData());
Guolin Ke's avatar
Guolin Ke committed
392

393
394
      this->ComputeBestSplitForFeature(
          larger_leaf_histogram_array_global_.get(), feature_index,
Nikita Titov's avatar
Nikita Titov committed
395
          real_feature_index,
396
397
          larger_node_used_features[feature_index],
          GetGlobalDataCountInLeaf(larger_leaf_splits_global_->leaf_index()),
398
399
          larger_leaf_splits_global_.get(), &larger_bests_per_thread[tid],
          larger_leaf_parent_output);
Guolin Ke's avatar
Guolin Ke committed
400
    }
401
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
402
  }
403
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
404
405

  auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread);
406
  int leaf = this->smaller_leaf_splits_->leaf_index();
Guolin Ke's avatar
Guolin Ke committed
407
  this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx];
Guolin Ke's avatar
Guolin Ke committed
408

409
410
  if (this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->leaf_index() >= 0) {
    leaf = this->larger_leaf_splits_->leaf_index();
411
412
    auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_bests_per_thread);
    this->best_split_per_leaf_[leaf] = larger_bests_per_thread[larger_best_idx];
Guolin Ke's avatar
Guolin Ke committed
413
  }
Guolin Ke's avatar
Guolin Ke committed
414
415

  // find local best
Guolin Ke's avatar
Guolin Ke committed
416
  SplitInfo smaller_best_split, larger_best_split;
417
  smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
Guolin Ke's avatar
Guolin Ke committed
418
  // find local best split for larger leaf
419
420
  if (this->larger_leaf_splits_->leaf_index() >= 0) {
    larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()];
Guolin Ke's avatar
Guolin Ke committed
421
422
  }
  // sync global best info
Guolin Ke's avatar
Guolin Ke committed
423
  SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);
Guolin Ke's avatar
Guolin Ke committed
424
425

  // copy back
426
427
428
  this->best_split_per_leaf_[smaller_leaf_splits_global_->leaf_index()] = smaller_best_split;
  if (larger_best_split.feature >= 0 && larger_leaf_splits_global_->leaf_index() >= 0) {
    this->best_split_per_leaf_[larger_leaf_splits_global_->leaf_index()] = larger_best_split;
Guolin Ke's avatar
Guolin Ke committed
429
430
431
  }
}

432
433
template <typename TREELEARNER_T>
void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
  TREELEARNER_T::SplitInner(tree, best_Leaf, left_leaf, right_leaf, false);
  const SplitInfo& best_split_info = this->best_split_per_leaf_[best_Leaf];
  // Record the global row counts of both children; later smaller/larger decisions rely on them.
  global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count;
  global_data_count_in_leaf_[*right_leaf] = best_split_info.right_count;
  // The child with strictly fewer global rows becomes the "smaller" global
  // leaf. Ties make the right child the smaller one.
  const bool left_is_smaller = best_split_info.left_count < best_split_info.right_count;
  LeafSplits* left_target = left_is_smaller ? smaller_leaf_splits_global_.get() : larger_leaf_splits_global_.get();
  LeafSplits* right_target = left_is_smaller ? larger_leaf_splits_global_.get() : smaller_leaf_splits_global_.get();
  left_target->Init(*left_leaf, this->data_partition_.get(),
                    best_split_info.left_sum_gradient,
                    best_split_info.left_sum_hessian,
                    best_split_info.left_output);
  right_target->Init(*right_leaf, this->data_partition_.get(),
                     best_split_info.right_sum_gradient,
                     best_split_info.right_sum_hessian,
                     best_split_info.right_output);
}

461
// instantiate template classes, otherwise linker cannot find the code
462
template class VotingParallelTreeLearner<CUDATreeLearner>;
463
464
template class VotingParallelTreeLearner<GPUTreeLearner>;
template class VotingParallelTreeLearner<SerialTreeLearner>;
465
}  // namespace LightGBM