/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#include <algorithm>
#include <cstring>
#include <tuple>
#include <vector>

#include "parallel_tree_learner.h"

namespace LightGBM {

template <typename TREELEARNER_T>
Guolin Ke's avatar
Guolin Ke committed
15
16
DataParallelTreeLearner<TREELEARNER_T>::DataParallelTreeLearner(const Config* config)
  :TREELEARNER_T(config) {
Guolin Ke's avatar
Guolin Ke committed
17
18
}

19
20
// Defaulted out-of-line: the class owns no raw resources, so the
// compiler-generated destructor is correct (Rule of Zero).
template <typename TREELEARNER_T>
DataParallelTreeLearner<TREELEARNER_T>::~DataParallelTreeLearner() = default;

/*!
 * \brief Initialize the learner: set up the serial base learner, query the
 *        network topology, and size all communication buffers.
 * \param train_data          training dataset (not owned)
 * \param is_constant_hessian forwarded to the serial learner
 */
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, bool is_constant_hessian) {
  // initialize SerialTreeLearner
  TREELEARNER_T::Init(train_data, is_constant_hessian);
  // Get local rank and global machine size
  rank_ = Network::rank();
  num_machines_ = Network::num_machines();

  auto max_cat_threshold = this->config_->max_cat_threshold;
  // need to be able to hold smaller and larger best splits in SyncUpGlobalBestSplit
  size_t split_info_size = static_cast<size_t>(SplitInfo::Size(max_cat_threshold) * 2);
  // histogram entries are 32-bit integers when gradients are quantized,
  // otherwise full-precision hist_t entries
  size_t histogram_size = this->config_->use_quantized_grad ?
    static_cast<size_t>(this->share_state_->num_hist_total_bin() * kInt32HistEntrySize) :
    static_cast<size_t>(this->share_state_->num_hist_total_bin() * kHistEntrySize);

  // allocate buffer for communication; the same buffers are reused for the
  // histogram reduce-scatter and for syncing best splits, so take the max
  size_t buffer_size = std::max(histogram_size, split_info_size);

  input_buffer_.resize(buffer_size);
  output_buffer_.resize(buffer_size);

  is_feature_aggregated_.resize(this->num_features_);

  block_start_.resize(num_machines_);
  block_len_.resize(num_machines_);

  buffer_write_start_pos_.resize(this->num_features_);
  buffer_read_start_pos_.resize(this->num_features_);

  // quantized-gradient training additionally keeps a 16-bit histogram layout;
  // all of its bookkeeping is sized together in one place
  if (this->config_->use_quantized_grad) {
    block_start_int16_.resize(num_machines_);
    block_len_int16_.resize(num_machines_);
    buffer_write_start_pos_int16_.resize(this->num_features_);
    buffer_read_start_pos_int16_.resize(this->num_features_);
  }

  global_data_count_in_leaf_.resize(this->config_->num_leaves);
}

/*!
 * \brief Apply a new configuration at runtime.
 * \param config new training configuration (not owned)
 */
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::ResetConfig(const Config* config) {
  // let the serial base learner process the new config first
  // (the resize below reads this->config_, which the base call is expected
  //  to have updated from `config` — keep this ordering)
  TREELEARNER_T::ResetConfig(config);
  // num_leaves may have changed; keep global per-leaf data counts sized to match
  global_data_count_in_leaf_.resize(this->config_->num_leaves);
}

template <typename TREELEARNER_T>
72
73
74
75
76
77
78
79
void DataParallelTreeLearner<TREELEARNER_T>::PrepareBufferPos(
  const std::vector<std::vector<int>>& feature_distribution,
  std::vector<comm_size_t>* block_start,
  std::vector<comm_size_t>* block_len,
  std::vector<comm_size_t>* buffer_write_start_pos,
  std::vector<comm_size_t>* buffer_read_start_pos,
  comm_size_t* reduce_scatter_size,
  size_t hist_entry_size) {
Guolin Ke's avatar
Guolin Ke committed
80
  // get block start and block len for reduce scatter
81
  *reduce_scatter_size = 0;
Guolin Ke's avatar
Guolin Ke committed
82
  for (int i = 0; i < num_machines_; ++i) {
83
    (*block_len)[i] = 0;
Guolin Ke's avatar
Guolin Ke committed
84
    for (auto fid : feature_distribution[i]) {
85
      auto num_bin = this->train_data_->FeatureNumBin(fid);
Guolin Ke's avatar
Guolin Ke committed
86
      if (this->train_data_->FeatureBinMapper(fid)->GetMostFreqBin() == 0) {
Guolin Ke's avatar
Guolin Ke committed
87
88
        num_bin -= 1;
      }
89
      (*block_len)[i] += num_bin * hist_entry_size;
Guolin Ke's avatar
Guolin Ke committed
90
    }
91
    *reduce_scatter_size += (*block_len)[i];
Guolin Ke's avatar
Guolin Ke committed
92
93
  }

94
  (*block_start)[0] = 0;
Guolin Ke's avatar
Guolin Ke committed
95
  for (int i = 1; i < num_machines_; ++i) {
96
    (*block_start)[i] = (*block_start)[i - 1] + (*block_len)[i - 1];
Guolin Ke's avatar
Guolin Ke committed
97
98
  }

99
  // get buffer_write_start_pos
Guolin Ke's avatar
Guolin Ke committed
100
101
102
  int bin_size = 0;
  for (int i = 0; i < num_machines_; ++i) {
    for (auto fid : feature_distribution[i]) {
103
      (*buffer_write_start_pos)[fid] = bin_size;
104
      auto num_bin = this->train_data_->FeatureNumBin(fid);
Guolin Ke's avatar
Guolin Ke committed
105
      if (this->train_data_->FeatureBinMapper(fid)->GetMostFreqBin() == 0) {
Guolin Ke's avatar
Guolin Ke committed
106
107
        num_bin -= 1;
      }
108
      bin_size += num_bin * hist_entry_size;
Guolin Ke's avatar
Guolin Ke committed
109
110
111
    }
  }

112
  // get buffer_read_start_pos
Guolin Ke's avatar
Guolin Ke committed
113
114
  bin_size = 0;
  for (auto fid : feature_distribution[rank_]) {
115
    (*buffer_read_start_pos)[fid] = bin_size;
116
    auto num_bin = this->train_data_->FeatureNumBin(fid);
Guolin Ke's avatar
Guolin Ke committed
117
    if (this->train_data_->FeatureBinMapper(fid)->GetMostFreqBin() == 0) {
Guolin Ke's avatar
Guolin Ke committed
118
119
      num_bin -= 1;
    }
120
    bin_size += num_bin * hist_entry_size;
Guolin Ke's avatar
Guolin Ke committed
121
  }
122
}
Guolin Ke's avatar
Guolin Ke committed
123

124
125
126
127
128
129
130
131
/*!
 * \brief Per-tree setup: partition features across machines, recompute the
 *        reduce-scatter buffer layout, and sync the root leaf's gradient /
 *        hessian sums across all machines.
 */
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::BeforeTrain() {
  TREELEARNER_T::BeforeTrain();
  // generate feature partition for current tree: greedily assign each used
  // feature to the machine that currently holds the fewest histogram bins,
  // so aggregation work is roughly balanced
  std::vector<std::vector<int>> feature_distribution(num_machines_, std::vector<int>());
  std::vector<int> num_bins_distributed(num_machines_, 0);
  for (int i = 0; i < this->train_data_->num_total_features(); ++i) {
    int inner_feature_index = this->train_data_->InnerFeatureIndex(i);
    // -1 means the feature has no inner representation (unused); skip it
    if (inner_feature_index == -1) {
      continue;
    }
    if (this->col_sampler_.is_feature_used_bytree()[inner_feature_index]) {
      int cur_min_machine = static_cast<int>(ArrayArgs<int>::ArgMin(num_bins_distributed));
      feature_distribution[cur_min_machine].push_back(inner_feature_index);
      auto num_bin = this->train_data_->FeatureNumBin(inner_feature_index);
      // a most-frequent bin of 0 is not stored in the histogram
      if (this->train_data_->FeatureBinMapper(inner_feature_index)->GetMostFreqBin() == 0) {
        num_bin -= 1;
      }
      num_bins_distributed[cur_min_machine] += num_bin;
    }
    is_feature_aggregated_[inner_feature_index] = false;
  }
  // get local used feature: this rank aggregates exactly the features assigned to it
  for (auto fid : feature_distribution[rank_]) {
    is_feature_aggregated_[fid] = true;
  }

  // get block start and block len for reduce scatter
  if (this->config_->use_quantized_grad) {
    // quantized training maintains both the 32-bit and the 16-bit histogram layouts
    PrepareBufferPos(feature_distribution, &block_start_, &block_len_, &buffer_write_start_pos_,
      &buffer_read_start_pos_, &reduce_scatter_size_, kInt32HistEntrySize);
    PrepareBufferPos(feature_distribution, &block_start_int16_, &block_len_int16_, &buffer_write_start_pos_int16_,
      &buffer_read_start_pos_int16_, &reduce_scatter_size_int16_, kInt16HistEntrySize);
  } else {
    PrepareBufferPos(feature_distribution, &block_start_, &block_len_, &buffer_write_start_pos_,
      &buffer_read_start_pos_, &reduce_scatter_size_, kHistEntrySize);
  }

  if (this->config_->use_quantized_grad) {
    // sync global data sumup info: (data count, sum gradients, sum hessians,
    // packed integer gradient/hessian sums)
    // NOTE(review): memcpy of a std::tuple assumes a trivially-copyable layout;
    // the standard does not guarantee this for std::tuple — works on the
    // supported toolchains, but worth confirming if porting
    std::tuple<data_size_t, double, double, int64_t> data(this->smaller_leaf_splits_->num_data_in_leaf(),
                                                          this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(),
                                                          this->smaller_leaf_splits_->int_sum_gradients_and_hessians());
    int size = sizeof(data);
    std::memcpy(input_buffer_.data(), &data, size);
    // global sumup reduce: element-wise sum of each machine's tuple
    Network::Allreduce(input_buffer_.data(), size, sizeof(std::tuple<data_size_t, double, double, int64_t>), output_buffer_.data(), [](const char *src, char *dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
      const std::tuple<data_size_t, double, double, int64_t> *p1;
      std::tuple<data_size_t, double, double, int64_t> *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const std::tuple<data_size_t, double, double, int64_t> *>(src);
        p2 = reinterpret_cast<std::tuple<data_size_t, double, double, int64_t> *>(dst);
        std::get<0>(*p2) = std::get<0>(*p2) + std::get<0>(*p1);
        std::get<1>(*p2) = std::get<1>(*p2) + std::get<1>(*p1);
        std::get<2>(*p2) = std::get<2>(*p2) + std::get<2>(*p1);
        std::get<3>(*p2) = std::get<3>(*p2) + std::get<3>(*p1);
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    // copy back
    std::memcpy(reinterpret_cast<void*>(&data), output_buffer_.data(), size);
    // set global sumup info on the root leaf
    this->smaller_leaf_splits_->Init(std::get<1>(data), std::get<2>(data), std::get<3>(data));
    // init global data count in leaf
    global_data_count_in_leaf_[0] = std::get<0>(data);
    // reset hist num bits according to global num data (root leaf only, no sibling yet)
    this->gradient_discretizer_->template SetNumBitsInHistogramBin<true>(0, -1, GetGlobalDataCountInLeaf(0), 0);
  } else {
    // sync global data sumup info: (data count, sum gradients, sum hessians)
    std::tuple<data_size_t, double, double> data(this->smaller_leaf_splits_->num_data_in_leaf(),
                                                 this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians());
    int size = sizeof(data);
    std::memcpy(input_buffer_.data(), &data, size);
    // global sumup reduce: element-wise sum of each machine's tuple
    Network::Allreduce(input_buffer_.data(), size, sizeof(std::tuple<data_size_t, double, double>), output_buffer_.data(), [](const char *src, char *dst, int type_size, comm_size_t len) {
      comm_size_t used_size = 0;
      const std::tuple<data_size_t, double, double> *p1;
      std::tuple<data_size_t, double, double> *p2;
      while (used_size < len) {
        p1 = reinterpret_cast<const std::tuple<data_size_t, double, double> *>(src);
        p2 = reinterpret_cast<std::tuple<data_size_t, double, double> *>(dst);
        std::get<0>(*p2) = std::get<0>(*p2) + std::get<0>(*p1);
        std::get<1>(*p2) = std::get<1>(*p2) + std::get<1>(*p1);
        std::get<2>(*p2) = std::get<2>(*p2) + std::get<2>(*p1);
        src += type_size;
        dst += type_size;
        used_size += type_size;
      }
    });
    // copy back
    std::memcpy(reinterpret_cast<void*>(&data), output_buffer_.data(), size);
    // set global sumup info on the root leaf
    this->smaller_leaf_splits_->Init(std::get<1>(data), std::get<2>(data));
    // init global data count in leaf
    global_data_count_in_leaf_[0] = std::get<0>(data);
  }
}

/*!
 * \brief Build local histograms, reduce-scatter them across machines, then
 *        search for best splits on the aggregated histograms.
 * \param tree current tree (used for per-node feature sampling downstream)
 */
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplits(const Tree* tree) {
  // construct local histograms for the smaller leaf
  TREELEARNER_T::ConstructHistograms(
      this->col_sampler_.is_feature_used_bytree(), true);
  const int smaller_leaf_index = this->smaller_leaf_splits_->leaf_index();
  const data_size_t local_data_on_smaller_leaf = this->data_partition_->leaf_count(smaller_leaf_index);
  if (local_data_on_smaller_leaf <= 0) {
    // clear histogram buffer before synchronizing
    // otherwise histogram contents from the previous iteration will be sent
    #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
    for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
      if (this->col_sampler_.is_feature_used_bytree()[feature_index] == false)
        continue;
      const BinMapper* feature_bin_mapper = this->train_data_->FeatureBinMapper(feature_index);
      // offset 1 when the most-frequent bin (bin 0) is omitted from the histogram
      const int offset = static_cast<int>(feature_bin_mapper->GetMostFreqBin() == 0);
      const int num_bin = feature_bin_mapper->num_bin();
      if (this->config_->use_quantized_grad) {
        // zero both the 32-bit and 16-bit histogram copies
        int32_t* hist_ptr = this->smaller_leaf_histogram_array_[feature_index].RawDataInt32();
        std::memset(reinterpret_cast<void*>(hist_ptr), 0, (num_bin - offset) * kInt32HistEntrySize);
        int16_t* hist_ptr_int16 = this->smaller_leaf_histogram_array_[feature_index].RawDataInt16();
        std::memset(reinterpret_cast<void*>(hist_ptr_int16), 0, (num_bin - offset) * kInt16HistEntrySize);
      } else {
        hist_t* hist_ptr = this->smaller_leaf_histogram_array_[feature_index].RawData();
        std::memset(reinterpret_cast<void*>(hist_ptr), 0, (num_bin - offset) * kHistEntrySize);
      }
    }
  }
  // pack every used feature's local histogram into the send buffer at its
  // precomputed write offset
  global_timer.Start("DataParallelTreeLearner::ReduceHistogram");
  global_timer.Start("DataParallelTreeLearner::ReduceHistogram::Copy");
  #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
  for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
    if (this->col_sampler_.is_feature_used_bytree()[feature_index] == false)
      continue;
    // copy to buffer
    if (this->config_->use_quantized_grad) {
      // <false>/<true> select local vs. globally-synced bit width for this leaf
      const uint8_t local_smaller_leaf_num_bits = this->gradient_discretizer_->template GetHistBitsInLeaf<false>(this->smaller_leaf_splits_->leaf_index());
      const uint8_t smaller_leaf_num_bits = this->gradient_discretizer_->template GetHistBitsInLeaf<true>(this->smaller_leaf_splits_->leaf_index());
      if (smaller_leaf_num_bits <= 16) {
        // global layout is 16-bit: send the 16-bit histogram directly
        std::memcpy(input_buffer_.data() + buffer_write_start_pos_int16_[feature_index],
                    this->smaller_leaf_histogram_array_[feature_index].RawDataInt16(),
                    this->smaller_leaf_histogram_array_[feature_index].SizeOfInt16Histogram());
      } else {
        if (local_smaller_leaf_num_bits == 32) {
          std::memcpy(input_buffer_.data() + buffer_write_start_pos_[feature_index],
                      this->smaller_leaf_histogram_array_[feature_index].RawDataInt32(),
                      this->smaller_leaf_histogram_array_[feature_index].SizeOfInt32Histogram());
        } else {
          // global layout is 32-bit but local data fit in 16-bit: widen while copying
          this->smaller_leaf_histogram_array_[feature_index].CopyFromInt16ToInt32(
            input_buffer_.data() + buffer_write_start_pos_[feature_index]);
        }
      }
    } else {
      std::memcpy(input_buffer_.data() + buffer_write_start_pos_[feature_index],
                this->smaller_leaf_histogram_array_[feature_index].RawData(),
                this->smaller_leaf_histogram_array_[feature_index].SizeOfHistogram());
    }
  }
  global_timer.Stop("DataParallelTreeLearner::ReduceHistogram::Copy");
  // Reduce scatter for histogram: each machine ends up with the summed
  // histograms for the features it aggregates
  global_timer.Start("DataParallelTreeLearner::ReduceHistogram::ReduceScatter");
  if (!this->config_->use_quantized_grad) {
    Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(hist_t), block_start_.data(),
                           block_len_.data(), output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &HistogramSumReducer);
  } else {
    const uint8_t smaller_leaf_num_bits = this->gradient_discretizer_->template GetHistBitsInLeaf<true>(this->smaller_leaf_splits_->leaf_index());
    if (smaller_leaf_num_bits <= 16) {
      Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_int16_, sizeof(int16_t), block_start_int16_.data(),
                            block_len_int16_.data(), output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &Int16HistogramSumReducer);
    } else {
      Network::ReduceScatter(input_buffer_.data(), reduce_scatter_size_, sizeof(int_hist_t), block_start_.data(),
                            block_len_.data(), output_buffer_.data(), static_cast<comm_size_t>(output_buffer_.size()), &Int32HistogramSumReducer);
    }
  }
  global_timer.Stop("DataParallelTreeLearner::ReduceHistogram::ReduceScatter");
  global_timer.Stop("DataParallelTreeLearner::ReduceHistogram");
  this->FindBestSplitsFromHistograms(
      this->col_sampler_.is_feature_used_bytree(), true, tree);
}

/*!
 * \brief Find the best split for the smaller and larger leaves from the
 *        globally aggregated histograms, then sync the global best across
 *        machines.
 *
 * Each machine only evaluates the features it aggregated (is_feature_aggregated_);
 * SyncUpGlobalBestSplit combines the per-machine winners.
 */
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const std::vector<int8_t>&, bool, const Tree* tree) {
  std::vector<SplitInfo> smaller_bests_per_thread(this->share_state_->num_threads);
  std::vector<SplitInfo> larger_bests_per_thread(this->share_state_->num_threads);
  std::vector<int8_t> smaller_node_used_features =
      this->col_sampler_.GetByNode(tree, this->smaller_leaf_splits_->leaf_index());
  std::vector<int8_t> larger_node_used_features =
      this->col_sampler_.GetByNode(tree, this->larger_leaf_splits_->leaf_index());
  double smaller_leaf_parent_output = this->GetParentOutput(tree, this->smaller_leaf_splits_.get());
  double larger_leaf_parent_output = this->GetParentOutput(tree, this->larger_leaf_splits_.get());

  // quantized path: when the parent histogram is 32-bit but the larger leaf only
  // needs 16-bit, stash the larger leaf's (parent) histogram into a side buffer
  // so the 32->16 subtraction below can read it
  if (this->config_->use_quantized_grad && this->larger_leaf_splits_ != nullptr && this->larger_leaf_splits_->leaf_index() >= 0) {
    const int parent_index = std::min(this->smaller_leaf_splits_->leaf_index(), this->larger_leaf_splits_->leaf_index());
    const uint8_t parent_num_bits = this->gradient_discretizer_->template GetHistBitsInNode<true>(parent_index);
    const uint8_t larger_leaf_num_bits = this->gradient_discretizer_->template GetHistBitsInLeaf<true>(this->larger_leaf_splits_->leaf_index());
    const uint8_t smaller_leaf_num_bits = this->gradient_discretizer_->template GetHistBitsInLeaf<true>(this->smaller_leaf_splits_->leaf_index());
    if (parent_num_bits > 16 && larger_leaf_num_bits <= 16) {
      CHECK_LE(smaller_leaf_num_bits, 16);
      OMP_INIT_EX();
      #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
      for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
        OMP_LOOP_EX_BEGIN();
        if (!is_feature_aggregated_[feature_index]) continue;
        this->larger_leaf_histogram_array_[feature_index].CopyToBuffer(this->gradient_discretizer_->GetChangeHistBitsBuffer(feature_index));
        OMP_LOOP_EX_END();
      }
      OMP_THROW_EX();
    }
  }

  OMP_INIT_EX();
  #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static)
  for (int feature_index = 0; feature_index < this->num_features_; ++feature_index) {
    OMP_LOOP_EX_BEGIN();
    // only features aggregated on this machine carry global histograms here
    if (!is_feature_aggregated_[feature_index]) continue;
    const int tid = omp_get_thread_num();
    const int real_feature_index = this->train_data_->RealFeatureIndex(feature_index);
    // restore global histograms from buffer
    if (this->config_->use_quantized_grad) {
      const uint8_t smaller_leaf_num_bits = this->gradient_discretizer_->template GetHistBitsInLeaf<true>(this->smaller_leaf_splits_->leaf_index());
      if (smaller_leaf_num_bits <= 16) {
        this->smaller_leaf_histogram_array_[feature_index].FromMemoryInt16(
          output_buffer_.data() + buffer_read_start_pos_int16_[feature_index]);
      } else {
        this->smaller_leaf_histogram_array_[feature_index].FromMemoryInt32(
          output_buffer_.data() + buffer_read_start_pos_[feature_index]);
      }
    } else {
      this->smaller_leaf_histogram_array_[feature_index].FromMemory(
        output_buffer_.data() + buffer_read_start_pos_[feature_index]);
    }

    // fix the omitted most-frequent bin using the (globally synced) leaf sums
    if (this->config_->use_quantized_grad) {
      const uint8_t smaller_leaf_num_bits = this->gradient_discretizer_->template GetHistBitsInLeaf<true>(this->smaller_leaf_splits_->leaf_index());
      const int64_t int_sum_gradient_and_hessian = this->smaller_leaf_splits_->int_sum_gradients_and_hessians();
      if (smaller_leaf_num_bits <= 16) {
        this->train_data_->template FixHistogramInt<int32_t, int32_t, 16, 16>(
          feature_index,
          int_sum_gradient_and_hessian,
          reinterpret_cast<hist_t*>(this->smaller_leaf_histogram_array_[feature_index].RawDataInt16()));
      } else {
        this->train_data_->template FixHistogramInt<int64_t, int64_t, 32, 32>(
          feature_index,
          int_sum_gradient_and_hessian,
          reinterpret_cast<hist_t*>(this->smaller_leaf_histogram_array_[feature_index].RawDataInt32()));
      }
    } else {
      this->train_data_->FixHistogram(feature_index,
                                      this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_hessians(),
                                      this->smaller_leaf_histogram_array_[feature_index].RawData());
    }

    this->ComputeBestSplitForFeature(
        this->smaller_leaf_histogram_array_, feature_index, real_feature_index,
        smaller_node_used_features[feature_index],
        GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->leaf_index()),
        this->smaller_leaf_splits_.get(),
        &smaller_bests_per_thread[tid],
        smaller_leaf_parent_output);

    // only root leaf: there is no larger leaf yet, so nothing more to do
    if (this->larger_leaf_splits_ == nullptr || this->larger_leaf_splits_->leaf_index() < 0) continue;

    // construct histograms for larger leaf: we init larger leaf as the parent,
    // so we can just subtract the smaller leaf's histograms
    if (this->config_->use_quantized_grad) {
      // pick the Subtract specialization matching the parent / smaller /
      // larger leaf histogram bit widths established above
      const int parent_index = std::min(this->smaller_leaf_splits_->leaf_index(), this->larger_leaf_splits_->leaf_index());
      const uint8_t parent_num_bits = this->gradient_discretizer_->template GetHistBitsInNode<true>(parent_index);
      const uint8_t larger_leaf_num_bits = this->gradient_discretizer_->template GetHistBitsInLeaf<true>(this->larger_leaf_splits_->leaf_index());
      const uint8_t smaller_leaf_num_bits = this->gradient_discretizer_->template GetHistBitsInLeaf<true>(this->smaller_leaf_splits_->leaf_index());
      if (parent_num_bits <= 16) {
        CHECK_LE(smaller_leaf_num_bits, 16);
        CHECK_LE(larger_leaf_num_bits, 16);
        this->larger_leaf_histogram_array_[feature_index].template Subtract<true, int32_t, int32_t, int32_t, 16, 16, 16>(
              this->smaller_leaf_histogram_array_[feature_index]);
      } else if (larger_leaf_num_bits <= 16) {
        CHECK_LE(smaller_leaf_num_bits, 16);
        // 32-bit parent was copied to the side buffer before this loop
        this->larger_leaf_histogram_array_[feature_index].template Subtract<true, int64_t, int32_t, int32_t, 32, 16, 16>(
            this->smaller_leaf_histogram_array_[feature_index], this->gradient_discretizer_->GetChangeHistBitsBuffer(feature_index));
      } else if (smaller_leaf_num_bits <= 16) {
        this->larger_leaf_histogram_array_[feature_index].template Subtract<true, int64_t, int32_t, int64_t, 32, 16, 32>(
              this->smaller_leaf_histogram_array_[feature_index]);
      } else {
        this->larger_leaf_histogram_array_[feature_index].template Subtract<true, int64_t, int64_t, int64_t, 32, 32, 32>(
              this->smaller_leaf_histogram_array_[feature_index]);
      }
    } else {
      this->larger_leaf_histogram_array_[feature_index].Subtract(
        this->smaller_leaf_histogram_array_[feature_index]);
    }

    this->ComputeBestSplitForFeature(
        this->larger_leaf_histogram_array_, feature_index, real_feature_index,
        larger_node_used_features[feature_index],
        GetGlobalDataCountInLeaf(this->larger_leaf_splits_->leaf_index()),
        this->larger_leaf_splits_.get(),
        &larger_bests_per_thread[tid],
        larger_leaf_parent_output);
    OMP_LOOP_EX_END();
  }
  OMP_THROW_EX();

  // reduce the per-thread winners to one best split per leaf (local machine)
  auto smaller_best_idx = ArrayArgs<SplitInfo>::ArgMax(smaller_bests_per_thread);
  int leaf = this->smaller_leaf_splits_->leaf_index();
  this->best_split_per_leaf_[leaf] = smaller_bests_per_thread[smaller_best_idx];

  if (this->larger_leaf_splits_ != nullptr &&  this->larger_leaf_splits_->leaf_index() >= 0) {
    leaf = this->larger_leaf_splits_->leaf_index();
    auto larger_best_idx = ArrayArgs<SplitInfo>::ArgMax(larger_bests_per_thread);
    this->best_split_per_leaf_[leaf] = larger_bests_per_thread[larger_best_idx];
  }

  SplitInfo smaller_best_split, larger_best_split;
  smaller_best_split = this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()];
  // find local best split for larger leaf
  if (this->larger_leaf_splits_->leaf_index() >= 0) {
    larger_best_split = this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()];
  }

  // sync global best info across all machines
  SyncUpGlobalBestSplit(input_buffer_.data(), input_buffer_.data(), &smaller_best_split, &larger_best_split, this->config_->max_cat_threshold);

  // set best split: every machine now holds the same global winners
  this->best_split_per_leaf_[this->smaller_leaf_splits_->leaf_index()] = smaller_best_split;
  if (this->larger_leaf_splits_->leaf_index() >= 0) {
    this->best_split_per_leaf_[this->larger_leaf_splits_->leaf_index()] = larger_best_split;
  }
}

/*!
 * \brief Perform the chosen split locally and update the global (cross-machine)
 *        bookkeeping for the two resulting leaves.
 * \param tree       tree being grown
 * \param best_Leaf  leaf to split
 * \param left_leaf  [out] index of the left child leaf
 * \param right_leaf [out] index of the right child leaf
 */
template <typename TREELEARNER_T>
void DataParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
  TREELEARNER_T::SplitInner(tree, best_Leaf, left_leaf, right_leaf, false);
  const SplitInfo& split = this->best_split_per_leaf_[best_Leaf];
  // record the globally-synced data counts of both children
  global_data_count_in_leaf_[*right_leaf] = split.right_count;
  global_data_count_in_leaf_[*left_leaf] = split.left_count;
  // with quantized gradients the per-leaf histogram bit width depends on the
  // global data count, so refresh it for both children
  if (this->config_->use_quantized_grad) {
    this->gradient_discretizer_->template SetNumBitsInHistogramBin<true>(
        *left_leaf, *right_leaf,
        GetGlobalDataCountInLeaf(*left_leaf), GetGlobalDataCountInLeaf(*right_leaf));
  }
}

// instantiate template classes, otherwise the linker cannot find the code
template class DataParallelTreeLearner<GPUTreeLearner>;
template class DataParallelTreeLearner<SerialTreeLearner>;

}  // namespace LightGBM