voting_parallel_tree_learner.cpp 20.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
9
10
#include <LightGBM/utils/common.h>

#include <cstring>
#include <tuple>
#include <vector>

11
12
#include "parallel_tree_learner.h"

Guolin Ke's avatar
Guolin Ke committed
13
14
namespace LightGBM {

15
template <typename TREELEARNER_T>
Guolin Ke's avatar
Guolin Ke committed
16
17
18
VotingParallelTreeLearner<TREELEARNER_T>::VotingParallelTreeLearner(const Config* config)
  :TREELEARNER_T(config) {
  top_k_ = this->config_->top_k;
Guolin Ke's avatar
Guolin Ke committed
19
20
}

21
22
23
template <typename TREELEARNER_T>
// Initializes the base learner, then sizes every buffer the voting protocol
// needs: communication buffers, per-feature aggregation flags/read offsets,
// ReduceScatter block layout, global leaf bookkeeping and the global
// (all-machine) histogram storage.
void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, bool is_constant_hessian) {
  TREELEARNER_T::Init(train_data, is_constant_hessian);
  rank_ = Network::rank();
  num_machines_ = Network::num_machines();

  // A machine can never vote for more features than the dataset has.
  if (top_k_ > this->num_features_) {
    top_k_ = this->num_features_;
  }
  // Largest bin count of any feature: an upper bound on the size of one histogram.
  int max_bin = 0;
  for (int i = 0; i < this->num_features_; ++i) {
    if (max_bin < this->train_data_->FeatureNumBin(i)) {
      max_bin = this->train_data_->FeatureNumBin(i);
    }
  }
  // The same two buffers carry either up to top_k_ histograms for each of the
  // smaller and larger leaf (ReduceScatter in FindBestSplits) or top_k_
  // (smaller, larger) LightSplitInfo pairs from every machine (Allgather),
  // hence 2 * top_k_ times the larger of the two units.
  size_t buffer_size = 2 * top_k_ * std::max(max_bin * kHistEntrySize, sizeof(LightSplitInfo) * num_machines_);
  input_buffer_.resize(buffer_size);
  output_buffer_.resize(buffer_size);

  smaller_is_feature_aggregated_.resize(this->num_features_);
  larger_is_feature_aggregated_.resize(this->num_features_);

  block_start_.resize(num_machines_);
  block_len_.resize(num_machines_);

  smaller_buffer_read_start_pos_.resize(this->num_features_);
  larger_buffer_read_start_pos_.resize(this->num_features_);
  global_data_count_in_leaf_.resize(this->config_->num_leaves);

  smaller_leaf_splits_global_.reset(new LeafSplits(train_data->num_data()));
  larger_leaf_splits_global_.reset(new LeafSplits(train_data->num_data()));

  // Local histograms are built with per-leaf minimums divided by the number of
  // machines (presumably because each machine only sees its own share of rows).
  local_config_ = *this->config_;
  local_config_.min_data_in_leaf /= num_machines_;
  local_config_.min_sum_hessian_in_leaf /= num_machines_;

  this->histogram_pool_.ResetConfig(train_data, &local_config_);

  // initialize histograms for global
  smaller_leaf_histogram_array_global_.reset(new FeatureHistogram[this->num_features_]);
  larger_leaf_histogram_array_global_.reset(new FeatureHistogram[this->num_features_]);
  // buffer_size above budgets kHistEntrySize bytes per histogram bin, so one
  // bin spans `values_per_bin` elements of the histogram storage. The offsets
  // below are accumulated in bins, therefore both the storage size and each
  // feature's start position are scaled by that factor; using raw bin counts
  // would make neighbouring features' histograms overlap and let the last
  // features run past the end of the vectors.
  const size_t values_per_bin = kHistEntrySize / sizeof(smaller_leaf_histogram_data_[0]);
  auto num_total_bin = train_data->NumTotalBin();
  smaller_leaf_histogram_data_.resize(static_cast<size_t>(num_total_bin) * values_per_bin);
  larger_leaf_histogram_data_.resize(static_cast<size_t>(num_total_bin) * values_per_bin);
  HistogramPool::SetFeatureInfo<true, true>(train_data, this->config_, &feature_metas_);
  uint64_t offset = 0;  // in bins
  for (int j = 0; j < train_data->num_features(); ++j) {
    offset += static_cast<uint64_t>(train_data->SubFeatureBinOffset(j));
    smaller_leaf_histogram_array_global_[j].Init(smaller_leaf_histogram_data_.data() + offset * values_per_bin, &feature_metas_[j]);
    larger_leaf_histogram_array_global_[j].Init(larger_leaf_histogram_data_.data() + offset * values_per_bin, &feature_metas_[j]);
    auto num_bin = train_data->FeatureNumBin(j);
    // one bin fewer is reserved when the most frequent bin is 0
    // (presumably that bin is not stored in the histogram)
    if (train_data->FeatureBinMapper(j)->GetMostFreqBin() == 0) {
      num_bin -= 1;
    }
    offset += static_cast<uint64_t>(num_bin);
  }
}

83
template <typename TREELEARNER_T>
Guolin Ke's avatar
Guolin Ke committed
84
85
void VotingParallelTreeLearner<TREELEARNER_T>::ResetConfig(const Config* config) {
  TREELEARNER_T::ResetConfig(config);
Guolin Ke's avatar
Guolin Ke committed
86

Guolin Ke's avatar
Guolin Ke committed
87
88
89
  local_config_ = *this->config_;
  local_config_.min_data_in_leaf /= num_machines_;
  local_config_.min_sum_hessian_in_leaf /= num_machines_;
Guolin Ke's avatar
Guolin Ke committed
90

91
  this->histogram_pool_.ResetConfig(this->train_data_, &local_config_);
Guolin Ke's avatar
Guolin Ke committed
92
  global_data_count_in_leaf_.resize(this->config_->num_leaves);
Guolin Ke's avatar
Guolin Ke committed
93

94
  HistogramPool::SetFeatureInfo<false, true>(this->train_data_, config, &feature_metas_);
Guolin Ke's avatar
Guolin Ke committed
95
}
Guolin Ke's avatar
Guolin Ke committed
96

97
98
99
template <typename TREELEARNER_T>
// Runs the base learner's per-tree setup, then sums the root leaf's
// (row count, gradient sum, hessian sum) over all machines with an Allreduce
// and seeds the global leaf-split state and global row count from the result.
void VotingParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
  TREELEARNER_T::BeforeTrain();
  // one record: (row count, sum of gradients, sum of hessians)
  using LeafSum = std::tuple<data_size_t, double, double>;
  LeafSum root_sum(this->smaller_leaf_splits_->num_data_in_leaf(),
                   this->smaller_leaf_splits_->sum_gradients(),
                   this->smaller_leaf_splits_->sum_hessians());
  const int record_size = sizeof(LeafSum);
  std::memcpy(input_buffer_.data(), &root_sum, record_size);

  // reducer: element-wise addition of the incoming records into the accumulator
  Network::Allreduce(input_buffer_.data(), record_size, record_size, output_buffer_.data(),
                     [](const char* src, char* dst, int type_size, comm_size_t len) {
    for (comm_size_t used_size = 0; used_size < len; used_size += type_size) {
      const LeafSum& incoming = *reinterpret_cast<const LeafSum*>(src + used_size);
      LeafSum& total = *reinterpret_cast<LeafSum*>(dst + used_size);
      std::get<0>(total) += std::get<0>(incoming);
      std::get<1>(total) += std::get<1>(incoming);
      std::get<2>(total) += std::get<2>(incoming);
    }
  });

  std::memcpy(&root_sum, output_buffer_.data(), record_size);

  // global sums for the root; the larger leaf starts out empty
  smaller_leaf_splits_global_->Init(std::get<1>(root_sum), std::get<2>(root_sum));
  larger_leaf_splits_global_->Init();
  global_data_count_in_leaf_[0] = std::get<0>(root_sum);
}

130
131
132
template <typename TREELEARNER_T>
// Delegates to the base learner; if it allows the split search, re-initializes
// the local smaller/larger leaf splits so that "smaller" is the child with fewer
// rows according to GetGlobalDataCountInLeaf (not the local row counts).
// Returns false only when the base learner vetoes the search.
bool VotingParallelTreeLearner<TREELEARNER_T>::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
  if (!TREELEARNER_T::BeforeFindBestSplit(tree, left_leaf, right_leaf)) {
    return false;
  }
  // no right child: nothing to reorder
  if (right_leaf < 0) {
    return true;
  }
  const bool left_is_smaller = GetGlobalDataCountInLeaf(left_leaf) < GetGlobalDataCountInLeaf(right_leaf);
  const int smaller_leaf = left_is_smaller ? left_leaf : right_leaf;
  const int larger_leaf = left_is_smaller ? right_leaf : left_leaf;
  // recompute local sums for both children
  this->smaller_leaf_splits_->Init(smaller_leaf, this->data_partition_.get(), this->gradients_, this->hessians_);
  this->larger_leaf_splits_->Init(larger_leaf, this->data_partition_.get(), this->gradients_, this->hessians_);
  return true;
}

152
template <typename TREELEARNER_T>
153
void VotingParallelTreeLearner<TREELEARNER_T>::GlobalVoting(int leaf_idx, const std::vector<LightSplitInfo>& splits, std::vector<int>* out) {
Guolin Ke's avatar
Guolin Ke committed
154
155
156
157
158
  out->clear();
  if (leaf_idx < 0) {
    return;
  }
  // get mean number on machines
159
  score_t mean_num_data = GetGlobalDataCountInLeaf(leaf_idx) / static_cast<score_t>(num_machines_);
160
  std::vector<LightSplitInfo> feature_best_split(this->train_data_->num_total_features() , LightSplitInfo());
Guolin Ke's avatar
Guolin Ke committed
161
162
163
164
165
166
167
168
169
170
171
172
173
  for (auto & split : splits) {
    int fid = split.feature;
    if (fid < 0) {
      continue;
    }
    // weighted gain
    double gain = split.gain * (split.left_count + split.right_count) / mean_num_data;
    if (gain > feature_best_split[fid].gain) {
      feature_best_split[fid] = split;
      feature_best_split[fid].gain = gain;
    }
  }
  // get top k
174
175
  std::vector<LightSplitInfo> top_k_splits;
  ArrayArgs<LightSplitInfo>::MaxK(feature_best_split, top_k_, &top_k_splits);
176
  std::stable_sort(top_k_splits.begin(), top_k_splits.end(), std::greater<LightSplitInfo>());
Guolin Ke's avatar
Guolin Ke committed
177
178
179
180
181
182
183
184
  for (auto& split : top_k_splits) {
    if (split.gain == kMinScore || split.feature == -1) {
      continue;
    }
    out->push_back(split.feature);
  }
}

185
186
187
template <typename TREELEARNER_T>
// Lays out the local histograms of the globally voted features in input_buffer_
// for the ReduceScatter. Features are dealt to machines in contiguous blocks of
// at most ceil(total / num_machines_) features, alternating smaller-leaf and
// larger-leaf features. For the block owned by this machine (rank_), the
// features are flagged as locally aggregated and their byte offset inside the
// block is recorded (later used to read them back from output_buffer_).
// Fills block_start_/block_len_ (bytes per machine) and reduce_scatter_size_.
void VotingParallelTreeLearner<TREELEARNER_T>::CopyLocalHistogram(const std::vector<int>& smaller_top_features, const std::vector<int>& larger_top_features) {
  // nothing is owned by this machine until the assignment below says so
  for (int i = 0; i < this->num_features_; ++i) {
    smaller_is_feature_aggregated_[i] = false;
    larger_is_feature_aggregated_[i] = false;
  }

  // Appends the local histogram of `real_feature` (smaller or larger leaf) to
  // input_buffer_; when `owned_here`, also records its offset in the block.
  auto append_histogram = [this](bool from_smaller, int real_feature, bool owned_here, size_t* block_bytes) {
    const int inner_feature_index = this->train_data_->InnerFeatureIndex(real_feature);
    if (owned_here) {
      if (from_smaller) {
        smaller_is_feature_aggregated_[inner_feature_index] = true;
        smaller_buffer_read_start_pos_[inner_feature_index] = static_cast<int>(*block_bytes);
      } else {
        larger_is_feature_aggregated_[inner_feature_index] = true;
        larger_buffer_read_start_pos_[inner_feature_index] = static_cast<int>(*block_bytes);
      }
    }
    auto& histogram = from_smaller ? this->smaller_leaf_histogram_array_[inner_feature_index]
                                   : this->larger_leaf_histogram_array_[inner_feature_index];
    const auto histogram_bytes = histogram.SizeOfHistgram();
    std::memcpy(input_buffer_.data() + reduce_scatter_size_, histogram.RawData(), histogram_bytes);
    *block_bytes += histogram_bytes;
    reduce_scatter_size_ += histogram_bytes;
  };

  const size_t total_num_features = smaller_top_features.size() + larger_top_features.size();
  const size_t features_per_machine = (total_num_features + num_machines_ - 1) / num_machines_;
  size_t next_smaller = 0;
  size_t next_larger = 0;
  size_t assigned = 0;
  block_start_[0] = 0;
  reduce_scatter_size_ = 0;
  for (int machine = 0; machine < num_machines_; ++machine) {
    const bool is_local = (machine == rank_);
    const size_t quota = std::min(features_per_machine, total_num_features - assigned);
    size_t block_bytes = 0;
    size_t taken = 0;
    // alternate: one smaller-leaf feature, then one larger-leaf feature
    while (taken < quota) {
      if (next_smaller < smaller_top_features.size()) {
        append_histogram(true, smaller_top_features[next_smaller++], is_local, &block_bytes);
        ++taken;
      }
      if (taken >= quota) {
        break;
      }
      if (next_larger < larger_top_features.size()) {
        append_histogram(false, larger_top_features[next_larger++], is_local, &block_bytes);
        ++taken;
      }
    }
    assigned += taken;
    block_len_[machine] = static_cast<int>(block_bytes);
    if (machine + 1 < num_machines_) {
      block_start_[machine + 1] = block_start_[machine] + block_len_[machine];
    }
  }
}

244
template <typename TREELEARNER_T>
Guolin Ke's avatar
Guolin Ke committed
245
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
Guolin Ke's avatar
Guolin Ke committed
246
  // use local data to find local best splits
247
  std::vector<int8_t> is_feature_used(this->num_features_, 0);
Guolin Ke's avatar
Guolin Ke committed
248
#pragma omp parallel for schedule(static)
249
  for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
250
    if (!this->col_sampler_.is_feature_used_bytree()[feature_index]) continue;
251
252
253
    if (this->parent_leaf_histogram_array_ != nullptr
      && !this->parent_leaf_histogram_array_[feature_index].is_splittable()) {
      this->smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
Guolin Ke's avatar
Guolin Ke committed
254
255
256
257
258
      continue;
    }
    is_feature_used[feature_index] = 1;
  }
  bool use_subtract = true;
259
  if (this->parent_leaf_histogram_array_ == nullptr) {
Guolin Ke's avatar
Guolin Ke committed
260
261
    use_subtract = false;
  }
Guolin Ke's avatar
Guolin Ke committed
262
  TREELEARNER_T::ConstructHistograms(is_feature_used, use_subtract);
Guolin Ke's avatar
Guolin Ke committed
263

264
265
  std::vector<SplitInfo> smaller_bestsplit_per_features(this->num_features_);
  std::vector<SplitInfo> larger_bestsplit_per_features(this->num_features_);
Guolin Ke's avatar
Guolin Ke committed
266

267
  OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
268
  // find splits
Guolin Ke's avatar
Guolin Ke committed
269
#pragma omp parallel for schedule(static)
270
  for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
271
    OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
272
    if (!is_feature_used[feature_index]) { continue; }
273
274
275
276
277
    const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
    this->train_data_->FixHistogram(feature_index,
      this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(),
      this->smaller_leaf_histogram_array_[feature_index].RawData());

278
279
280
281
282
    this->ComputeBestSplitForFeature(
        this->smaller_leaf_histogram_array_, feature_index, real_feature_index,
        true, this->smaller_leaf_splits_->num_data_in_leaf(),
        this->smaller_leaf_splits_.get(),
        &smaller_bestsplit_per_features[feature_index]);
Guolin Ke's avatar
Guolin Ke committed
283
    // only has root leaf
284
    if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->leaf_index() < 0) { continue; }
Guolin Ke's avatar
Guolin Ke committed
285
286

    if (use_subtract) {
287
      this->larger_leaf_histogram_array_[feature_index].Subtract(this->smaller_leaf_histogram_array_[feature_index]);
Guolin Ke's avatar
Guolin Ke committed
288
    } else {
289
290
      this->train_data_->FixHistogram(feature_index, this->larger_leaf_splits_->sum_gradients(), this->larger_leaf_splits_->sum_hessians(),
        this->larger_leaf_histogram_array_[feature_index].RawData());
Guolin Ke's avatar
Guolin Ke committed
291
    }
292
293
294
295
296
    this->ComputeBestSplitForFeature(
        this->larger_leaf_histogram_array_, feature_index, real_feature_index,
        true, this->larger_leaf_splits_->num_data_in_leaf(),
        this->larger_leaf_splits_.get(),
        &larger_bestsplit_per_features[feature_index]);
297
    OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
298
  }
299
  OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
300

Guolin Ke's avatar
Guolin Ke committed
301
302
  std::vector<SplitInfo> smaller_top_k_splits, larger_top_k_splits;
  // local voting
Guolin Ke's avatar
Guolin Ke committed
303
304
  ArrayArgs<SplitInfo>::MaxK(smaller_bestsplit_per_features, top_k_, &smaller_top_k_splits);
  ArrayArgs<SplitInfo>::MaxK(larger_bestsplit_per_features, top_k_, &larger_top_k_splits);
305
306
307
308
309
310
311
312

  std::vector<LightSplitInfo> smaller_top_k_light_splits(top_k_);
  std::vector<LightSplitInfo> larger_top_k_light_splits(top_k_);
  for (int i = 0; i < top_k_; ++i) {
    smaller_top_k_light_splits[i].CopyFrom(smaller_top_k_splits[i]);
    larger_top_k_light_splits[i].CopyFrom(larger_top_k_splits[i]);
  }

Guolin Ke's avatar
Guolin Ke committed
313
314
315
  // gather
  int offset = 0;
  for (int i = 0; i < top_k_; ++i) {
316
317
318
319
    std::memcpy(input_buffer_.data() + offset, &smaller_top_k_light_splits[i], sizeof(LightSplitInfo));
    offset += sizeof(LightSplitInfo);
    std::memcpy(input_buffer_.data() + offset, &larger_top_k_light_splits[i], sizeof(LightSplitInfo));
    offset += sizeof(LightSplitInfo);
Guolin Ke's avatar
Guolin Ke committed
320
321
322
  }
  Network::Allgather(input_buffer_.data(), offset, output_buffer_.data());
  // get all top-k from all machines
323
324
  std::vector<LightSplitInfo> smaller_top_k_splits_global;
  std::vector<LightSplitInfo> larger_top_k_splits_global;
Guolin Ke's avatar
Guolin Ke committed
325
326
327
  offset = 0;
  for (int i = 0; i < num_machines_; ++i) {
    for (int j = 0; j < top_k_; ++j) {
328
329
330
331
332
333
      smaller_top_k_splits_global.push_back(LightSplitInfo());
      std::memcpy(&smaller_top_k_splits_global.back(), output_buffer_.data() + offset, sizeof(LightSplitInfo));
      offset += sizeof(LightSplitInfo);
      larger_top_k_splits_global.push_back(LightSplitInfo());
      std::memcpy(&larger_top_k_splits_global.back(), output_buffer_.data() + offset, sizeof(LightSplitInfo));
      offset += sizeof(LightSplitInfo);
Guolin Ke's avatar
Guolin Ke committed
334
335
336
337
    }
  }
  // global voting
  std::vector<int> smaller_top_features, larger_top_features;
338
339
  GlobalVoting(this->smaller_leaf_splits_->leaf_index(), smaller_top_k_splits_global, &smaller_top_features);
  GlobalVoting(this->larger_leaf_splits_->leaf_index(), larger_top_k_splits_global, &larger_top_features);
Guolin Ke's avatar
Guolin Ke committed
340
341
342
343
  // copy local histgrams to buffer
  CopyLocalHistogram(smaller_top_features, larger_top_features);

  // Reduce scatter for histogram
344
345
  Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(), block_len_.data(),
                         output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);
Guolin Ke's avatar
Guolin Ke committed
346

Guolin Ke's avatar
Guolin Ke committed
347
348
349
350
351
  this->FindBestSplitsFromHistograms(is_feature_used, false);
}

template <typename TREELEARNER_T>
// Searches the globally reduced histograms owned by this machine (those flagged
// in smaller/larger_is_feature_aggregated_) for the best split of each leaf.
// It records the local winners in best_split_per_leaf_, exchanges them via
// SyncUpGlobalBestSplit, and writes the synchronized result back. Both parameters
// are unused: the feature selection comes from the aggregation flags.
void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool) {
  const int thread_count = this->share_state_->num_threads;
  std::vector<SplitInfo> smaller_bests_per_thread(thread_count);
  std::vector<SplitInfo> larger_bests_per_thread(thread_count);
  std::vector<int8_t> smaller_node_used_features =
      this->col_sampler_.GetByNode();
  std::vector<int8_t> larger_node_used_features =
      this->col_sampler_.GetByNode();

  // find best split from local aggregated histograms
  OMP_INIT_EX();
#pragma omp parallel for schedule(static) num_threads(thread_count)
  for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
    OMP_LOOP_EX_BEGIN();
    const int tid = omp_get_thread_num();
    const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
    if (smaller_is_feature_aggregated_[feature_index]) {
      // restore the reduced histogram from this machine's output block
      smaller_leaf_histogram_array_global_[feature_index].FromMemory(
          output_buffer_.data() + smaller_buffer_read_start_pos_[feature_index]);
      this->train_data_->FixHistogram(feature_index,
                                      smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_hessians(),
                                      smaller_leaf_histogram_array_global_[feature_index].RawData());
      this->ComputeBestSplitForFeature(
          smaller_leaf_histogram_array_global_.get(), feature_index,
          real_feature_index, smaller_node_used_features[feature_index],
          GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->leaf_index()),
          smaller_leaf_splits_global_.get(), &smaller_bests_per_thread[tid]);
    }

    if (larger_is_feature_aggregated_[feature_index]) {
      // restore the reduced histogram from this machine's output block
      larger_leaf_histogram_array_global_[feature_index].FromMemory(
          output_buffer_.data() + larger_buffer_read_start_pos_[feature_index]);
      this->train_data_->FixHistogram(feature_index,
                                      larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_hessians(),
                                      larger_leaf_histogram_array_global_[feature_index].RawData());
      this->ComputeBestSplitForFeature(
          larger_leaf_histogram_array_global_.get(), feature_index,
          real_feature_index, larger_node_used_features[feature_index],
          GetGlobalDataCountInLeaf(larger_leaf_splits_global_->leaf_index()),
          larger_leaf_splits_global_.get(), &larger_bests_per_thread[tid]);
    }
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();

  // Best split this machine found for each leaf. The larger-leaf check is done
  // once, with the nullptr guard, and reused (the old read-back dereferenced
  // larger_leaf_splits_ without checking it).
  const int smaller_leaf = this->smaller_leaf_splits_->leaf_index();
  SplitInfo smaller_best_split = smaller_bests_per_thread[ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread)];
  this->best_split_per_leaf_[smaller_leaf] = smaller_best_split;

  SplitInfo larger_best_split;
  const bool has_larger_leaf = this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->leaf_index() >= 0;
  if (has_larger_leaf) {
    larger_best_split = larger_bests_per_thread[ArrayArgs<SplitInfo>::ArgMax(larger_bests_per_thread)];
    this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()] = larger_best_split;
  }

  // sync global best info across machines
  SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);

  // copy back the synchronized splits
  this->best_split_per_leaf_[smaller_leaf_splits_global_->leaf_index()] = smaller_best_split;
  if (larger_best_split.feature >= 0 && larger_leaf_splits_global_->leaf_index() >= 0) {
    this->best_split_per_leaf_[larger_leaf_splits_global_->leaf_index()] = larger_best_split;
  }
}

428
429
template <typename TREELEARNER_T>
// Performs the split through the base learner, then records the global row
// counts of both children and initializes the global leaf-split state. The
// child with fewer global rows (left_count vs right_count) becomes the
// "smaller" leaf; ties go to the right child.
void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
  TREELEARNER_T::SplitInner(tree, best_Leaf, left_leaf, right_leaf, false);
  const SplitInfo& split = this->best_split_per_leaf_[best_Leaf];

  // set the global number of data for leaves
  global_data_count_in_leaf_[*left_leaf] = split.left_count;
  global_data_count_in_leaf_[*right_leaf] = split.right_count;

  const bool left_smaller = split.left_count < split.right_count;
  smaller_leaf_splits_global_->Init(left_smaller ? *left_leaf : *right_leaf, this->data_partition_.get(),
    left_smaller ? split.left_sum_gradient : split.right_sum_gradient,
    left_smaller ? split.left_sum_hessian : split.right_sum_hessian,
    left_smaller ? split.left_output : split.right_output);
  larger_leaf_splits_global_->Init(left_smaller ? *right_leaf : *left_leaf, this->data_partition_.get(),
    left_smaller ? split.right_sum_gradient : split.left_sum_gradient,
    left_smaller ? split.right_sum_hessian : split.left_sum_hessian,
    left_smaller ? split.right_output : split.left_output);
}

457
458
459
// Explicit instantiations: the member templates are defined only in this .cpp,
// so every base learner used with the voting learner must be instantiated here
// or the linker will not find the definitions.
template class VotingParallelTreeLearner<SerialTreeLearner>;
template class VotingParallelTreeLearner<GPUTreeLearner>;
460
}  // namespace LightGBM