gpu_tree_learner.cpp 51.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2017 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
5
#ifdef USE_GPU
6

7
8
9
#include "gpu_tree_learner.h"

#include <LightGBM/bin.h>
10
11
#include <LightGBM/network.h>
#include <LightGBM/utils/array_args.h>
12

13
14
#include <algorithm>

15
#include "../io/dense_bin.hpp"
16
17
18
19
20

#define GPU_DEBUG 0

namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
21
22
/*
 * Construct the GPU tree learner on top of SerialTreeLearner.
 * Bagging is switched off (use_bagging_ = false) and a banner is logged;
 * the OpenCL device, buffers and kernels are set up later, from Init().
 */
GPUTreeLearner::GPUTreeLearner(const Config* config)
    : SerialTreeLearner(config) {
  use_bagging_ = false;
  Log::Info("This is the GPU trainer!!");
}

/*
 * Release the host-side mappings of the pinned gradient, hessian and
 * feature-mask buffers. A null mapping pointer is treated as "not mapped"
 * and skipped.
 */
GPUTreeLearner::~GPUTreeLearner() {
  auto unmap_if_mapped = [this](const boost::compute::buffer& pinned, void* mapped) {
    if (mapped != nullptr) {
      queue_.enqueue_unmap_buffer(pinned, mapped);
    }
  };
  unmap_if_mapped(pinned_gradients_, ptr_pinned_gradients_);
  unmap_if_mapped(pinned_hessians_, ptr_pinned_hessians_);
  unmap_if_mapped(pinned_feature_masks_, ptr_pinned_feature_masks_);
}

void GPUTreeLearner::Init(const Dataset* train_data, bool is_constant_hessian) {
  // initialize SerialTreeLearner
  SerialTreeLearner::Init(train_data, is_constant_hessian);
  // some additional variables needed for GPU trainer
  num_feature_groups_ = train_data_->num_feature_groups();
  // Initialize GPU buffers and kernels
Guolin Ke's avatar
Guolin Ke committed
45
  InitGPU(config_->gpu_platform_id, config_->gpu_device_id);
46
47
48
49
50
}

// some functions used for debugging the GPU histogram construction
#if GPU_DEBUG > 0

51
52
/*
 * Debug helper: print `size` histogram entries (gradient, hessian) of `h`,
 * four entries per row, followed by the sum of all hessians.
 */
void PrintHistograms(hist_t* h, size_t size) {
  double total_hess = 0;
  for (size_t i = 0; i < size; ++i) {
    // %zu is the portable specifier for size_t (%lu is wrong where size_t != unsigned long)
    printf("%03zu=%9.3g,%9.3g\t", i, GET_GRAD(h, i), GET_HESS(h, i));
    // break the line after every 4th entry (i = 3, 7, 11, ...)
    if ((i & 3) == 3) {
      printf("\n");
    }
    total_hess += GET_HESS(h, i);
  }
  printf("\nSum hessians: %9.3g\n", total_hess);
}

62
/*
 * Debug helper: a double together with a 64-bit integer view of its bits.
 * ulp_diff() returns the absolute distance between the two bit patterns,
 * which for finite doubles of the same sign is the distance in ULPs.
 */
union Float_t {
    int64_t i;
    double f;
    static int64_t ulp_diff(Float_t a, Float_t b) {
      // Copy the object bytes instead of reading the (possibly inactive) member `i`,
      // which would be undefined behaviour in C++.
      int64_t ia;
      int64_t ib;
      memcpy(&ia, &a, sizeof(ia));
      memcpy(&ib, &b, sizeof(ib));
      // Unsigned subtraction cannot overflow; values of opposite sign can be more
      // than INT64_MAX apart, so the result saturates at INT64_MAX.
      const uint64_t diff = ia > ib ? static_cast<uint64_t>(ia) - static_cast<uint64_t>(ib)
                                    : static_cast<uint64_t>(ib) - static_cast<uint64_t>(ia);
      const uint64_t max_diff = static_cast<uint64_t>(-1) >> 1;
      return static_cast<int64_t>(diff > max_diff ? max_diff : diff);
    }
};
69

70

71
/*
 * Debug helper: compare two histograms of `size` entries, entry by entry.
 * Only hessians are checked; gradient differences are currently not reported.
 * On the first hessian pair that differs by at least 1e-20, print the pair and
 * its ULP distance, log a warning with feature_id and the entry index, dump
 * both histograms, and wait for key presses on stdin.
 */
void CompareHistograms(hist_t* h1, hist_t* h2, size_t size, int feature_id) {
  for (size_t i = 0; i < size; ++i) {
    Float_t a, b;
    a.f = GET_HESS(h1, i);
    b.f = GET_HESS(h2, i);
    if (std::fabs(a.f - b.f) >= 1e-20) {
      // keep the full 64-bit ULP count (it was previously truncated to 32 bits)
      const long long ulps = static_cast<long long>(Float_t::ulp_diff(a, b));
      printf("hessian %g != %g (%lld ULPs)\n", GET_HESS(h1, i), GET_HESS(h2, i), ulps);
      Log::Warning("Mismatched histograms found for feature %d at location %zu.", feature_id, i);
      std::cin.get();
      PrintHistograms(h1, size);
      printf("\n");
      PrintHistograms(h2, size);
      std::cin.get();
      return;
    }
  }
}
#endif

/*
 * Choose log2 of the number of workgroups that cooperate on one Feature4 tuple.
 *
 * Two targets are combined, and the smaller exponent wins:
 *   - about 256 workgroups in total, spread over num_dense_feature4_ tuples;
 *   - about 1024 leaf examples per workgroup. Because the exponent is rounded
 *     up, each workgroup ends up with between 512 and 1024 examples.
 * The result is clamped to [0, kMaxLogWorkgroupsPerFeature].
 *
 * The exponent stays a double until it has been clamped. Degenerate inputs
 * (an empty leaf gives log2(0) = -inf; num_dense_feature4_ == 0 gives +inf)
 * therefore never reach an out-of-range float-to-int conversion.
 */
int GPUTreeLearner::GetNumWorkgroupsPerFeature(data_size_t leaf_num_data) {
  // exponent that gives ~256 workgroups over all feature4 tuples
  const double x = 256.0 / num_dense_feature4_;
  const double exp_for_occupancy = ceil(log2(x));
  // exponent that gives ~1024 examples per workgroup. log2 is exact for powers of two,
  // whereas log(t) / log(2.0) can land just above an integer and round up by one.
  const double t = leaf_num_data / 1024.0;
  const double exp_for_data = ceil(log2(t));
  #if GPU_DEBUG >= 4
  printf("Computing histogram for %d examples and (%d * %d) feature groups\n", leaf_num_data, dword_features_, num_dense_feature4_);
  printf("We can have at most %d workgroups per feature4 for efficiency reasons.\n"
         "Best workgroup size per feature for full utilization is %d\n", static_cast<int>(ceil(t)), (1 << static_cast<int>(exp_for_occupancy)));
  #endif
  const double exp_workgroups_per_feature = std::min(exp_for_occupancy, exp_for_data);
  // written as !(v >= 0) so that a NaN exponent is also mapped to 0
  if (!(exp_workgroups_per_feature >= 0.0)) {
    return 0;
  }
  if (exp_workgroups_per_feature > kMaxLogWorkgroupsPerFeature) {
    return kMaxLogWorkgroupsPerFeature;
  }
  return static_cast<int>(exp_workgroups_per_feature);
}

/*
 * Launch the histogram kernel for a leaf with `leaf_num_data` examples and start
 * an asynchronous map of device_histogram_outputs_ into host memory.
 *
 * The host-to-device copies of gradients, hessians and indices were started
 * earlier (their futures are waited on here). The full-data kernel is used when
 * the leaf covers all num_data_ rows; otherwise the "all features" or masked
 * kernel is chosen according to use_all_features. The completion event of the
 * map is stored in histograms_wait_obj_ for WaitAndGetHistograms().
 */
void GPUTreeLearner::GPUHistogram(data_size_t leaf_num_data, bool use_all_features) {
  // each 2^log_wg workgroups cooperate on one feature4 tuple
  const int log_wg = GetNumWorkgroupsPerFeature(leaf_num_data);
  const int num_workgroups = (1 << log_wg) * num_dense_feature4_;

  // grow the sub-histogram scratch buffer when this launch needs more workgroups
  if (num_workgroups > preallocd_max_num_wg_) {
    preallocd_max_num_wg_ = num_workgroups;
    Log::Info("Increasing preallocd_max_num_wg_ to %d for launching more workgroups", preallocd_max_num_wg_);
    device_subhistograms_.reset(new boost::compute::vector<char>(
                              preallocd_max_num_wg_ * dword_features_ * device_bin_size_ * hist_bin_entry_sz_, ctx_));
    // every kernel variant holds the scratch buffer as argument 7; rebind it to the new buffer
    for (int p = 0; p <= kMaxLogWorkgroupsPerFeature; ++p) {
      histogram_kernels_[p].set_arg(7, *device_subhistograms_);
      histogram_allfeats_kernels_[p].set_arg(7, *device_subhistograms_);
      histogram_fulldata_kernels_[p].set_arg(7, *device_subhistograms_);
    }
  }
  #if GPU_DEBUG >= 4
  printf("Setting exp_workgroups_per_feature to %d, using %u work groups\n", log_wg, num_workgroups);
  printf("Constructing histogram with %d examples\n", leaf_num_data);
  #endif

  // argument 4 of the partial-data kernel variant carries the leaf size
  auto& leaf_kernel = use_all_features ? histogram_allfeats_kernels_[log_wg] : histogram_kernels_[log_wg];
  leaf_kernel.set_arg(4, leaf_num_data);

  // wait for the pending host-to-device copies this launch depends on:
  // indices are not copied for the root node, hessians not for constant hessians
  if (leaf_num_data != num_data_) {
    indices_future_.wait();
  }
  if (!share_state_->is_constant_hessian) {
    hessians_future_.wait();
  }
  gradients_future_.wait();

  // the root node (all rows) uses the dedicated full-data kernel
  auto& launch_kernel = (leaf_num_data == num_data_) ? histogram_fulldata_kernels_[log_wg] : leaf_kernel;
  kernel_wait_obj_ = boost::compute::wait_list(
    queue_.enqueue_1d_range_kernel(launch_kernel, 0, num_workgroups * 256, 256));

  // map the output asynchronously once the kernel finishes; entry size depends on precision
  const size_t output_size = num_dense_feature4_ * dword_features_ * device_bin_size_ * hist_bin_entry_sz_;
  boost::compute::event histogram_wait_event;
  host_histogram_outputs_ = reinterpret_cast<void*>(queue_.enqueue_map_buffer_async(
    device_histogram_outputs_, boost::compute::command_queue::map_read, 0, output_size, histogram_wait_event, kernel_wait_obj_));
  histograms_wait_obj_ = boost::compute::wait_list(histogram_wait_event);
}

template <typename HistType>
190
void GPUTreeLearner::WaitAndGetHistograms(hist_t* histograms) {
191
  HistType* hist_outputs = reinterpret_cast<HistType*>(host_histogram_outputs_);
192
193
194
  // when the output is ready, the computation is done
  histograms_wait_obj_.wait();
  #pragma omp parallel for schedule(static)
195
  for (int i = 0; i < num_dense_feature_groups_; ++i) {
196
197
198
199
    if (!feature_masks_[i]) {
      continue;
    }
    int dense_group_index = dense_feature_group_map_[i];
200
    auto old_histogram_array = histograms + train_data_->GroupBinBoundary(dense_group_index) * 2;
201
    int bin_size = train_data_->FeatureGroupNumBin(dense_group_index);
202
203
    if (device_bin_mults_[i] == 1) {
      for (int j = 0; j < bin_size; ++j) {
204
205
        GET_GRAD(old_histogram_array, j) = GET_GRAD(hist_outputs, i * device_bin_size_+ j);
        GET_HESS(old_histogram_array, j) = GET_HESS(hist_outputs, i * device_bin_size_+ j);
206
      }
207
    } else {
208
209
210
211
212
      // values of this feature has been redistributed to multiple bins; need a reduction here
      int ind = 0;
      for (int j = 0; j < bin_size; ++j) {
        double sum_g = 0.0, sum_h = 0.0;
        for (int k = 0; k < device_bin_mults_[i]; ++k) {
213
214
          sum_g += GET_GRAD(hist_outputs, i * device_bin_size_+ ind);
          sum_h += GET_HESS(hist_outputs, i * device_bin_size_+ ind);
215
216
          ind++;
        }
217
218
        GET_GRAD(old_histogram_array, j) = sum_g;
        GET_HESS(old_histogram_array, j) = sum_h;
219
220
221
222
223
224
225
226
227
      }
    }
  }
  queue_.enqueue_unmap_buffer(device_histogram_outputs_, host_histogram_outputs_);
}

void GPUTreeLearner::AllocateGPUMemory() {
  num_dense_feature_groups_ = 0;
  for (int i = 0; i < num_feature_groups_; ++i) {
228
    if (!train_data_->IsMultiGroup(i)) {
229
230
231
232
233
234
235
236
237
238
239
240
241
242
      num_dense_feature_groups_++;
    }
  }
  // how many feature-group tuples we have
  num_dense_feature4_ = (num_dense_feature_groups_ + (dword_features_ - 1)) / dword_features_;
  // leave some safe margin for prefetching
  // 256 work-items per workgroup. Each work-item prefetches one tuple for that feature
  int allocated_num_data_ = num_data_ + 256 * (1 << kMaxLogWorkgroupsPerFeature);
  // clear sparse/dense maps
  dense_feature_group_map_.clear();
  device_bin_mults_.clear();
  sparse_feature_group_map_.clear();
  // do nothing if no features can be processed on GPU
  if (!num_dense_feature_groups_) {
Lingyi Hu's avatar
Lingyi Hu committed
243
    Log::Warning("GPU acceleration is disabled because no non-trivial dense features can be found");
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
    return;
  }
  // allocate memory for all features (FIXME: 4 GB barrier on some devices, need to split to multiple buffers)
  device_features_.reset();
  device_features_ = std::unique_ptr<boost::compute::vector<Feature4>>(new boost::compute::vector<Feature4>(num_dense_feature4_ * num_data_, ctx_));
  // unpin old buffer if necessary before destructing them
  if (ptr_pinned_gradients_) {
    queue_.enqueue_unmap_buffer(pinned_gradients_, ptr_pinned_gradients_);
  }
  if (ptr_pinned_hessians_) {
    queue_.enqueue_unmap_buffer(pinned_hessians_, ptr_pinned_hessians_);
  }
  if (ptr_pinned_feature_masks_) {
    queue_.enqueue_unmap_buffer(pinned_feature_masks_, ptr_pinned_feature_masks_);
  }
259
  // make ordered_gradients and hessians larger (including extra room for prefetching), and pin them
260
261
  ordered_gradients_.reserve(allocated_num_data_);
  ordered_hessians_.reserve(allocated_num_data_);
262
263
264
  pinned_gradients_ = boost::compute::buffer();  // deallocate
  pinned_gradients_ = boost::compute::buffer(ctx_, allocated_num_data_ * sizeof(score_t),
                                             boost::compute::memory_object::read_write | boost::compute::memory_object::use_host_ptr,
265
                                             ordered_gradients_.data());
266
  ptr_pinned_gradients_ = queue_.enqueue_map_buffer(pinned_gradients_, boost::compute::command_queue::map_write_invalidate_region,
267
                                                    0, allocated_num_data_ * sizeof(score_t));
268
269
270
  pinned_hessians_ = boost::compute::buffer();  // deallocate
  pinned_hessians_  = boost::compute::buffer(ctx_, allocated_num_data_ * sizeof(score_t),
                                             boost::compute::memory_object::read_write | boost::compute::memory_object::use_host_ptr,
271
                                             ordered_hessians_.data());
272
  ptr_pinned_hessians_ = queue_.enqueue_map_buffer(pinned_hessians_, boost::compute::command_queue::map_write_invalidate_region,
273
                                                   0, allocated_num_data_ * sizeof(score_t));
274
275
  // allocate space for gradients and hessians on device
  // we will copy gradients and hessians in after ordered_gradients_ and ordered_hessians_ are constructed
276
277
  device_gradients_ = boost::compute::buffer();  // deallocate
  device_gradients_ = boost::compute::buffer(ctx_, allocated_num_data_ * sizeof(score_t),
278
                      boost::compute::memory_object::read_only, nullptr);
279
280
  device_hessians_ = boost::compute::buffer();  // deallocate
  device_hessians_  = boost::compute::buffer(ctx_, allocated_num_data_ * sizeof(score_t),
281
282
283
                      boost::compute::memory_object::read_only, nullptr);
  // allocate feature mask, for disabling some feature-groups' histogram calculation
  feature_masks_.resize(num_dense_feature4_ * dword_features_);
284
285
  device_feature_masks_ = boost::compute::buffer();  // deallocate
  device_feature_masks_ = boost::compute::buffer(ctx_, num_dense_feature4_ * dword_features_,
286
                          boost::compute::memory_object::read_only, nullptr);
287
288
  pinned_feature_masks_ = boost::compute::buffer(ctx_, num_dense_feature4_ * dword_features_,
                                             boost::compute::memory_object::read_write | boost::compute::memory_object::use_host_ptr,
289
290
291
292
293
294
295
296
297
                                             feature_masks_.data());
  ptr_pinned_feature_masks_ = queue_.enqueue_map_buffer(pinned_feature_masks_, boost::compute::command_queue::map_write_invalidate_region,
                                                        0, num_dense_feature4_ * dword_features_);
  memset(ptr_pinned_feature_masks_, 0, num_dense_feature4_ * dword_features_);
  // copy indices to the device
  device_data_indices_.reset();
  device_data_indices_ = std::unique_ptr<boost::compute::vector<data_size_t>>(new boost::compute::vector<data_size_t>(allocated_num_data_, ctx_));
  boost::compute::fill(device_data_indices_->begin(), device_data_indices_->end(), 0, queue_);
  // histogram bin entry size depends on the precision (single/double)
298
  hist_bin_entry_sz_ = config_->gpu_use_dp ? sizeof(hist_t) * 2 : sizeof(gpu_hist_t) * 2;
299
300
301
302
303
304
305
306
307
308
309
310
311
312
  Log::Info("Size of histogram bin entry: %d", hist_bin_entry_sz_);
  // create output buffer, each feature has a histogram with device_bin_size_ bins,
  // each work group generates a sub-histogram of dword_features_ features.
  if (!device_subhistograms_) {
    // only initialize once here, as this will not need to change when ResetTrainingData() is called
    device_subhistograms_ = std::unique_ptr<boost::compute::vector<char>>(new boost::compute::vector<char>(
                              preallocd_max_num_wg_ * dword_features_ * device_bin_size_ * hist_bin_entry_sz_, ctx_));
  }
  // create atomic counters for inter-group coordination
  sync_counters_.reset();
  sync_counters_ = std::unique_ptr<boost::compute::vector<int>>(new boost::compute::vector<int>(
                    num_dense_feature4_, ctx_));
  boost::compute::fill(sync_counters_->begin(), sync_counters_->end(), 0, queue_);
  // The output buffer is allocated to host directly, to overlap compute and data transfer
313
  device_histogram_outputs_ = boost::compute::buffer();  // deallocate
314
  device_histogram_outputs_ = boost::compute::buffer(ctx_, num_dense_feature4_ * dword_features_ * device_bin_size_ * hist_bin_entry_sz_,
315
316
                           boost::compute::memory_object::write_only | boost::compute::memory_object::alloc_host_ptr, nullptr);
  // find the dense feature-groups and group then into Feature4 data structure (several feature-groups packed into 4 bytes)
317
318
319
  int k = 0, copied_feature4 = 0;
  std::vector<int> dense_dword_ind(dword_features_);
  for (int i = 0; i < num_feature_groups_; ++i) {
320
    // looking for dword_features_ non-sparse feature-groups
321
    if (!train_data_->IsMultiGroup(i)) {
322
      dense_dword_ind[k] = i;
323
      // decide if we need to redistribute the bin
324
      double t = device_bin_size_ / static_cast<double>(train_data_->FeatureGroupNumBin(i));
325
      // multiplier must be a power of 2
326
      device_bin_mults_.push_back(static_cast<int>(round(pow(2, floor(log2(t))))));
327
328
329
330
331
      // device_bin_mults_.push_back(1);
      #if GPU_DEBUG >= 1
      printf("feature-group %d using multiplier %d\n", i, device_bin_mults_.back());
      #endif
      k++;
332
    } else {
333
334
      sparse_feature_group_map_.push_back(i);
    }
335
    // found
336
337
338
    if (k == dword_features_) {
      k = 0;
      for (int j = 0; j < dword_features_; ++j) {
339
        dense_feature_group_map_.push_back(dense_dword_ind[j]);
340
341
342
343
344
345
346
      }
      copied_feature4++;
    }
  }
  // for data transfer time
  auto start_time = std::chrono::steady_clock::now();
  // Now generate new data structure feature4, and copy data to the device
347
  int nthreads = std::min(omp_get_max_threads(), static_cast<int>(dense_feature_group_map_.size()) / dword_features_);
348
349
350
351
352
353
  nthreads = std::max(nthreads, 1);
  std::vector<Feature4*> host4_vecs(nthreads);
  std::vector<boost::compute::buffer> host4_bufs(nthreads);
  std::vector<Feature4*> host4_ptrs(nthreads);
  // preallocate arrays for all threads, and pin them
  for (int i = 0; i < nthreads; ++i) {
354
    host4_vecs[i] = reinterpret_cast<Feature4*>(boost::alignment::aligned_alloc(4096, num_data_ * sizeof(Feature4)));
355
356
    host4_bufs[i] = boost::compute::buffer(ctx_, num_data_ * sizeof(Feature4),
                    boost::compute::memory_object::read_write | boost::compute::memory_object::use_host_ptr,
357
                    host4_vecs[i]);
358
359
    host4_ptrs[i] = reinterpret_cast<Feature4*>(queue_.enqueue_map_buffer(host4_bufs[i], boost::compute::command_queue::map_write_invalidate_region,
                    0, num_data_ * sizeof(Feature4)));
360
361
362
  }
  // building Feature4 bundles; each thread handles dword_features_ features
  #pragma omp parallel for schedule(static)
363
  for (int i = 0; i < static_cast<int>(dense_feature_group_map_.size() / dword_features_); ++i) {
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
    int tid = omp_get_thread_num();
    Feature4* host4 = host4_ptrs[tid];
    auto dense_ind = dense_feature_group_map_.begin() + i * dword_features_;
    auto dev_bin_mult = device_bin_mults_.begin() + i * dword_features_;
    #if GPU_DEBUG >= 1
    printf("Copying feature group ");
    for (int l = 0; l < dword_features_; ++l) {
      printf("%d ", dense_ind[l]);
    }
    printf("to devices\n");
    #endif
    if (dword_features_ == 8) {
      // one feature datapoint is 4 bits
      BinIterator* bin_iters[8];
      for (int s_idx = 0; s_idx < 8; ++s_idx) {
        bin_iters[s_idx] = train_data_->FeatureGroupIterator(dense_ind[s_idx]);
380
        if (dynamic_cast<DenseBinIterator<uint8_t, true>*>(bin_iters[s_idx]) == 0) {
381
          Log::Fatal("GPU tree learner assumes that all bins are Dense4bitsBin when num_bin <= 16, but feature %d is not", dense_ind[s_idx]);
382
383
384
        }
      }
      // this guarantees that the RawGet() function is inlined, rather than using virtual function dispatching
385
386
387
388
389
390
391
392
393
      DenseBinIterator<uint8_t, true> iters[8] = {
        *static_cast<DenseBinIterator<uint8_t, true>*>(bin_iters[0]),
        *static_cast<DenseBinIterator<uint8_t, true>*>(bin_iters[1]),
        *static_cast<DenseBinIterator<uint8_t, true>*>(bin_iters[2]),
        *static_cast<DenseBinIterator<uint8_t, true>*>(bin_iters[3]),
        *static_cast<DenseBinIterator<uint8_t, true>*>(bin_iters[4]),
        *static_cast<DenseBinIterator<uint8_t, true>*>(bin_iters[5]),
        *static_cast<DenseBinIterator<uint8_t, true>*>(bin_iters[6]),
        *static_cast<DenseBinIterator<uint8_t, true>*>(bin_iters[7])};
394
      for (int j = 0; j < num_data_; ++j) {
395
        host4[j].s[0] = (uint8_t)((iters[0].RawGet(j) * dev_bin_mult[0] + ((j+0) & (dev_bin_mult[0] - 1)))
396
                      |((iters[1].RawGet(j) * dev_bin_mult[1] + ((j+1) & (dev_bin_mult[1] - 1))) << 4));
397
        host4[j].s[1] = (uint8_t)((iters[2].RawGet(j) * dev_bin_mult[2] + ((j+2) & (dev_bin_mult[2] - 1)))
398
                      |((iters[3].RawGet(j) * dev_bin_mult[3] + ((j+3) & (dev_bin_mult[3] - 1))) << 4));
399
        host4[j].s[2] = (uint8_t)((iters[4].RawGet(j) * dev_bin_mult[4] + ((j+4) & (dev_bin_mult[4] - 1)))
400
                      |((iters[5].RawGet(j) * dev_bin_mult[5] + ((j+5) & (dev_bin_mult[5] - 1))) << 4));
401
        host4[j].s[3] = (uint8_t)((iters[6].RawGet(j) * dev_bin_mult[6] + ((j+6) & (dev_bin_mult[6] - 1)))
402
                      |((iters[7].RawGet(j) * dev_bin_mult[7] + ((j+7) & (dev_bin_mult[7] - 1))) << 4));
403
      }
404
    } else if (dword_features_ == 4) {
405
406
407
408
      // one feature datapoint is one byte
      for (int s_idx = 0; s_idx < 4; ++s_idx) {
        BinIterator* bin_iter = train_data_->FeatureGroupIterator(dense_ind[s_idx]);
        // this guarantees that the RawGet() function is inlined, rather than using virtual function dispatching
409
        if (dynamic_cast<DenseBinIterator<uint8_t, false>*>(bin_iter) != 0) {
410
          // Dense bin
411
          DenseBinIterator<uint8_t, false> iter = *static_cast<DenseBinIterator<uint8_t, false>*>(bin_iter);
412
          for (int j = 0; j < num_data_; ++j) {
413
            host4[j].s[s_idx] = (uint8_t)(iter.RawGet(j) * dev_bin_mult[s_idx] + ((j+s_idx) & (dev_bin_mult[s_idx] - 1)));
414
          }
415
        } else if (dynamic_cast<DenseBinIterator<uint8_t, true>*>(bin_iter) != 0) {
416
          // Dense 4-bit bin
417
          DenseBinIterator<uint8_t, true> iter = *static_cast<DenseBinIterator<uint8_t, true>*>(bin_iter);
418
          for (int j = 0; j < num_data_; ++j) {
419
            host4[j].s[s_idx] = (uint8_t)(iter.RawGet(j) * dev_bin_mult[s_idx] + ((j+s_idx) & (dev_bin_mult[s_idx] - 1)));
420
          }
421
        } else {
422
          Log::Fatal("Bug in GPU tree builder: only DenseBin and Dense4bitsBin are supported");
423
424
        }
      }
425
    } else {
426
      Log::Fatal("Bug in GPU tree builder: dword_features_ can only be 4 or 8");
427
    }
Vladimir's avatar
Vladimir committed
428
    #pragma omp critical
429
430
431
    queue_.enqueue_write_buffer(device_features_->get_buffer(),
                        i * num_data_ * sizeof(Feature4), num_data_ * sizeof(Feature4), host4);
    #if GPU_DEBUG >= 1
432
    printf("first example of feature-group tuple is: %d %d %d %d\n", host4[0].s[0], host4[0].s[1], host4[0].s[2], host4[0].s[3]);
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
    printf("Feature-groups copied to device with multipliers ");
    for (int l = 0; l < dword_features_; ++l) {
      printf("%d ", dev_bin_mult[l]);
    }
    printf("\n");
    #endif
  }
  // working on the remaining (less than dword_features_) feature groups
  if (k != 0) {
    Feature4* host4 = host4_ptrs[0];
    if (dword_features_ == 8) {
      memset(host4, 0, num_data_ * sizeof(Feature4));
    }
    #if GPU_DEBUG >= 1
    printf("%d features left\n", k);
    #endif
449
    for (int i = 0; i < k; ++i) {
450
      if (dword_features_ == 8) {
451
        BinIterator* bin_iter = train_data_->FeatureGroupIterator(dense_dword_ind[i]);
452
453
        if (dynamic_cast<DenseBinIterator<uint8_t, true>*>(bin_iter) != 0) {
          DenseBinIterator<uint8_t, true> iter = *static_cast<DenseBinIterator<uint8_t, true>*>(bin_iter);
454
455
          #pragma omp parallel for schedule(static)
          for (int j = 0; j < num_data_; ++j) {
456
            host4[j].s[i >> 1] |= (uint8_t)((iter.RawGet(j) * device_bin_mults_[copied_feature4 * dword_features_ + i]
457
458
459
                                + ((j+i) & (device_bin_mults_[copied_feature4 * dword_features_ + i] - 1)))
                               << ((i & 1) << 2));
          }
460
        } else {
461
          Log::Fatal("GPU tree learner assumes that all bins are Dense4bitsBin when num_bin <= 16, but feature %d is not", dense_dword_ind[i]);
462
        }
463
      } else if (dword_features_ == 4) {
464
        BinIterator* bin_iter = train_data_->FeatureGroupIterator(dense_dword_ind[i]);
465
466
        if (dynamic_cast<DenseBinIterator<uint8_t, false>*>(bin_iter) != 0) {
          DenseBinIterator<uint8_t, false> iter = *static_cast<DenseBinIterator<uint8_t, false>*>(bin_iter);
467
468
          #pragma omp parallel for schedule(static)
          for (int j = 0; j < num_data_; ++j) {
469
            host4[j].s[i] = (uint8_t)(iter.RawGet(j) * device_bin_mults_[copied_feature4 * dword_features_ + i]
470
                          + ((j+i) & (device_bin_mults_[copied_feature4 * dword_features_ + i] - 1)));
471
          }
472
473
        } else if (dynamic_cast<DenseBinIterator<uint8_t, true>*>(bin_iter) != 0) {
          DenseBinIterator<uint8_t, true> iter = *static_cast<DenseBinIterator<uint8_t, true>*>(bin_iter);
474
475
          #pragma omp parallel for schedule(static)
          for (int j = 0; j < num_data_; ++j) {
476
            host4[j].s[i] = (uint8_t)(iter.RawGet(j) * device_bin_mults_[copied_feature4 * dword_features_ + i]
477
                          + ((j+i) & (device_bin_mults_[copied_feature4 * dword_features_ + i] - 1)));
478
          }
479
        } else {
480
          Log::Fatal("BUG in GPU tree builder: only DenseBin and Dense4bitsBin are supported");
481
        }
482
      } else {
483
        Log::Fatal("Bug in GPU tree builder: dword_features_ can only be 4 or 8");
484
485
486
487
488
489
      }
    }
    // fill the leftover features
    if (dword_features_ == 8) {
      #pragma omp parallel for schedule(static)
      for (int j = 0; j < num_data_; ++j) {
490
        for (int i = k; i < dword_features_; ++i) {
491
          // fill this empty feature with some "random" value
492
          host4[j].s[i >> 1] |= (uint8_t)((j & 0xf) << ((i & 1) << 2));
493
494
        }
      }
495
    } else if (dword_features_ == 4) {
496
497
      #pragma omp parallel for schedule(static)
      for (int j = 0; j < num_data_; ++j) {
498
        for (int i = k; i < dword_features_; ++i) {
499
          // fill this empty feature with some "random" value
500
          host4[j].s[i] = (uint8_t)j;
501
502
503
504
505
506
507
508
509
        }
      }
    }
    // copying the last 1 to (dword_features - 1) feature-groups in the last tuple
    queue_.enqueue_write_buffer(device_features_->get_buffer(),
                        (num_dense_feature4_ - 1) * num_data_ * sizeof(Feature4), num_data_ * sizeof(Feature4), host4);
    #if GPU_DEBUG >= 1
    printf("Last features copied to device\n");
    #endif
510
511
    for (int i = 0; i < k; ++i) {
      dense_feature_group_map_.push_back(dense_dword_ind[i]);
512
513
514
515
516
517
518
519
520
521
    }
  }
  // deallocate pinned space for feature copying
  for (int i = 0; i < nthreads; ++i) {
      queue_.enqueue_unmap_buffer(host4_bufs[i], host4_ptrs[i]);
      host4_bufs[i] = boost::compute::buffer();
      boost::alignment::aligned_free(host4_vecs[i]);
  }
  // data transfer time
  std::chrono::duration<double, std::milli> end_time = std::chrono::steady_clock::now() - start_time;
522
523
  Log::Info("%d dense feature groups (%.2f MB) transferred to GPU in %f secs. %d sparse feature groups",
            dense_feature_group_map_.size(), ((dense_feature_group_map_.size() + (dword_features_ - 1)) / dword_features_) * num_data_ * sizeof(Feature4) / (1024.0 * 1024.0),
524
525
526
            end_time * 1e-3, sparse_feature_group_map_.size());
  #if GPU_DEBUG >= 1
  printf("Dense feature group list (size %lu): ", dense_feature_group_map_.size());
527
  for (int i = 0; i < num_dense_feature_groups_; ++i) {
528
529
530
531
    printf("%d ", dense_feature_group_map_[i]);
  }
  printf("\n");
  printf("Sparse feature group list (size %lu): ", sparse_feature_group_map_.size());
532
  for (int i = 0; i < num_feature_groups_ - num_dense_feature_groups_; ++i) {
533
534
535
536
537
538
    printf("%d ", sparse_feature_group_map_[i]);
  }
  printf("\n");
  #endif
}

539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
/*
 * Rebuild kernel_source_ with the given compiler options and return the OpenCL
 * build log. On a successful build the log may still contain warnings.
 * If the build throws, the log is queried only for CL_INVALID_PROGRAM and
 * CL_BUILD_PROGRAM_FAILURE (for other failures build_log() can crash).
 * "No log available.\n" is returned when no log can be obtained.
 */
std::string GPUTreeLearner::GetBuildLog(const std::string &opts) {
  boost::compute::program program = boost::compute::program::create_with_source(kernel_source_, ctx_);
  try {
    program.build(opts);
  } catch (boost::compute::opencl_error &e) {
    const auto code = e.error_code();
    const bool log_is_queryable = (code == CL_INVALID_PROGRAM) || (code == CL_BUILD_PROGRAM_FAILURE);
    if (!log_is_queryable) {
      return std::string("No log available.\n");
    }
    try {
      return program.build_log();
    } catch (...) {
      // querying the log failed as well; fall back to the placeholder text
      return std::string("No log available.\n");
    }
  }
  return program.build_log();
}

562
563
564
565
566
567
568
569
570
571
572
573
void GPUTreeLearner::BuildGPUKernels() {
  Log::Info("Compiling OpenCL Kernel with %d bins...", device_bin_size_);
  // destroy any old kernels
  histogram_kernels_.clear();
  histogram_allfeats_kernels_.clear();
  histogram_fulldata_kernels_.clear();
  // create OpenCL kernels for different number of workgroups per feature
  histogram_kernels_.resize(kMaxLogWorkgroupsPerFeature+1);
  histogram_allfeats_kernels_.resize(kMaxLogWorkgroupsPerFeature+1);
  histogram_fulldata_kernels_.resize(kMaxLogWorkgroupsPerFeature+1);
  // currently we don't use constant memory
  int use_constants = 0;
574
  OMP_INIT_EX();
575
576
  #pragma omp parallel for schedule(guided)
  for (int i = 0; i <= kMaxLogWorkgroupsPerFeature; ++i) {
577
    OMP_LOOP_EX_BEGIN();
578
579
    boost::compute::program program;
    std::ostringstream opts;
580
    // compile the GPU kernel depending if double precision is used, constant hessian is used, etc.
581
    opts << " -D POWER_FEATURE_WORKGROUPS=" << i
Guolin Ke's avatar
Guolin Ke committed
582
         << " -D USE_CONSTANT_BUF=" << use_constants << " -D USE_DP_FLOAT=" << int(config_->gpu_use_dp)
583
         << " -D CONST_HESSIAN=" << int(share_state_->is_constant_hessian)
584
         << " -cl-mad-enable -cl-no-signed-zeros -cl-fast-relaxed-math";
585
586
587
588
589
590
591
592
    #if GPU_DEBUG >= 1
    std::cout << "Building GPU kernels with options: " << opts.str() << std::endl;
    #endif
    // kernel with indices in an array
    try {
      program = boost::compute::program::build_with_source(kernel_source_, ctx_, opts.str());
    }
    catch (boost::compute::opencl_error &e) {
593
594
595
596
597
      #pragma omp critical
      {
        std::cerr << "Build Options:" << opts.str() << std::endl;
        std::cerr << "Build Log:" << std::endl << GetBuildLog(opts.str()) << std::endl;
        Log::Fatal("Cannot build GPU program: %s", e.what());
598
599
600
      }
    }
    histogram_kernels_[i] = program.create_kernel(kernel_name_);
601

602
603
604
605
606
607
    // kernel with all features enabled, with elimited branches
    opts << " -D ENABLE_ALL_FEATURES=1";
    try {
      program = boost::compute::program::build_with_source(kernel_source_, ctx_, opts.str());
    }
    catch (boost::compute::opencl_error &e) {
608
609
610
611
612
      #pragma omp critical
      {
        std::cerr << "Build Options:" << opts.str() << std::endl;
        std::cerr << "Build Log:" << std::endl << GetBuildLog(opts.str()) << std::endl;
        Log::Fatal("Cannot build GPU program: %s", e.what());
613
614
615
616
617
618
619
620
621
622
      }
    }
    histogram_allfeats_kernels_[i] = program.create_kernel(kernel_name_);

    // kernel with all data indices (for root node, and assumes that root node always uses all features)
    opts << " -D IGNORE_INDICES=1";
    try {
      program = boost::compute::program::build_with_source(kernel_source_, ctx_, opts.str());
    }
    catch (boost::compute::opencl_error &e) {
623
624
625
626
627
      #pragma omp critical
      {
        std::cerr << "Build Options:" << opts.str() << std::endl;
        std::cerr << "Build Log:" << std::endl << GetBuildLog(opts.str()) << std::endl;
        Log::Fatal("Cannot build GPU program: %s", e.what());
628
629
630
      }
    }
    histogram_fulldata_kernels_[i] = program.create_kernel(kernel_name_);
631
    OMP_LOOP_EX_END();
632
  }
633
  OMP_THROW_EX();
634
635
636
637
638
639
640
641
642
643
  Log::Info("GPU programs have been built");
}

void GPUTreeLearner::SetupKernelArguments() {
  // do nothing if no features can be processed on GPU
  if (!num_dense_feature_groups_) {
    return;
  }
  for (int i = 0; i <= kMaxLogWorkgroupsPerFeature; ++i) {
    // The only argument that needs to be changed later is num_data_
644
    if (share_state_->is_constant_hessian) {
645
      // hessian is passed as a parameter, but it is not available now.
646
647
648
649
650
651
652
653
654
655
      // hessian will be set in BeforeTrain()
      histogram_kernels_[i].set_args(*device_features_, device_feature_masks_, num_data_,
                                         *device_data_indices_, num_data_, device_gradients_, 0.0f,
                                         *device_subhistograms_, *sync_counters_, device_histogram_outputs_);
      histogram_allfeats_kernels_[i].set_args(*device_features_, device_feature_masks_, num_data_,
                                         *device_data_indices_, num_data_, device_gradients_, 0.0f,
                                         *device_subhistograms_, *sync_counters_, device_histogram_outputs_);
      histogram_fulldata_kernels_[i].set_args(*device_features_, device_feature_masks_, num_data_,
                                          *device_data_indices_, num_data_, device_gradients_, 0.0f,
                                          *device_subhistograms_, *sync_counters_, device_histogram_outputs_);
656
    } else {
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
      histogram_kernels_[i].set_args(*device_features_, device_feature_masks_, num_data_,
                                         *device_data_indices_, num_data_, device_gradients_, device_hessians_,
                                         *device_subhistograms_, *sync_counters_, device_histogram_outputs_);
      histogram_allfeats_kernels_[i].set_args(*device_features_, device_feature_masks_, num_data_,
                                         *device_data_indices_, num_data_, device_gradients_, device_hessians_,
                                         *device_subhistograms_, *sync_counters_, device_histogram_outputs_);
      histogram_fulldata_kernels_[i].set_args(*device_features_, device_feature_masks_, num_data_,
                                          *device_data_indices_, num_data_, device_gradients_, device_hessians_,
                                          *device_subhistograms_, *sync_counters_, device_histogram_outputs_);
    }
  }
}

void GPUTreeLearner::InitGPU(int platform_id, int device_id) {
  // Get the max bin size, used for selecting best GPU kernel
  max_num_bin_ = 0;
  #if GPU_DEBUG >= 1
  printf("bin size: ");
  #endif
  for (int i = 0; i < num_feature_groups_; ++i) {
677
678
679
    if (train_data_->IsMultiGroup(i)) {
      continue;
    }
680
681
682
683
684
685
686
687
688
689
690
691
    #if GPU_DEBUG >= 1
    printf("%d, ", train_data_->FeatureGroupNumBin(i));
    #endif
    max_num_bin_ = std::max(max_num_bin_, train_data_->FeatureGroupNumBin(i));
  }
  #if GPU_DEBUG >= 1
  printf("\n");
  #endif
  // initialize GPU
  dev_ = boost::compute::system::default_device();
  if (platform_id >= 0 && device_id >= 0) {
    const std::vector<boost::compute::platform> platforms = boost::compute::system::platforms();
692
    if (static_cast<int>(platforms.size()) > platform_id) {
693
      const std::vector<boost::compute::device> platform_devices = platforms[platform_id].devices();
694
      if (static_cast<int>(platform_devices.size()) > device_id) {
695
696
        Log::Info("Using requested OpenCL platform %d device %d", platform_id, device_id);
        dev_ = platform_devices[device_id];
697
698
699
      }
    }
  }
700
701
  // determine which kernel to use based on the max number of bins
  if (max_num_bin_ <= 16) {
Guolin Ke's avatar
Guolin Ke committed
702
703
    // the +9 skips extra characters ")", newline, "#endif" and newline at the beginning
    kernel_source_ = kernel16_src_ + 9;
704
705
706
    kernel_name_ = "histogram16";
    device_bin_size_ = 16;
    dword_features_ = 8;
707
  } else if (max_num_bin_ <= 64) {
Guolin Ke's avatar
Guolin Ke committed
708
709
    // the +9 skips extra characters ")", newline, "#endif" and newline at the beginning
    kernel_source_ = kernel64_src_ + 9;
710
711
712
    kernel_name_ = "histogram64";
    device_bin_size_ = 64;
    dword_features_ = 4;
713
  } else if (max_num_bin_ <= 256) {
Guolin Ke's avatar
Guolin Ke committed
714
715
    // the +9 skips extra characters ")", newline, "#endif" and newline at the beginning
    kernel_source_ = kernel256_src_ + 9;
716
717
718
    kernel_name_ = "histogram256";
    device_bin_size_ = 256;
    dword_features_ = 4;
719
  } else {
720
721
    Log::Fatal("bin size %d cannot run on GPU", max_num_bin_);
  }
722
  if (max_num_bin_ == 65) {
723
724
    Log::Warning("Setting max_bin to 63 is sugguested for best performance");
  }
725
  if (max_num_bin_ == 17) {
726
727
728
729
730
731
732
733
734
735
736
    Log::Warning("Setting max_bin to 15 is sugguested for best performance");
  }
  ctx_ = boost::compute::context(dev_);
  queue_ = boost::compute::command_queue(ctx_, dev_);
  Log::Info("Using GPU Device: %s, Vendor: %s", dev_.name().c_str(), dev_.vendor().c_str());
  BuildGPUKernels();
  AllocateGPUMemory();
  // setup GPU kernel arguments after we allocating all the buffers
  SetupKernelArguments();
}

737
738
Tree* GPUTreeLearner::Train(const score_t* gradients, const score_t *hessians) {
  return SerialTreeLearner::Train(gradients, hessians);
739
740
}

741
742
void GPUTreeLearner::ResetTrainingDataInner(const Dataset* train_data, bool is_constant_hessian, bool reset_multi_val_bin) {
  SerialTreeLearner::ResetTrainingDataInner(train_data, is_constant_hessian, reset_multi_val_bin);
743
744
745
746
747
748
749
  num_feature_groups_ = train_data_->num_feature_groups();
  // GPU memory has to been reallocated because data may have been changed
  AllocateGPUMemory();
  // setup GPU kernel arguments after we allocating all the buffers
  SetupKernelArguments();
}

750
void GPUTreeLearner::ResetIsConstantHessian(bool is_constant_hessian) {
Nikita Titov's avatar
Nikita Titov committed
751
  if (is_constant_hessian != share_state_->is_constant_hessian) {
752
    SerialTreeLearner::ResetIsConstantHessian(is_constant_hessian);
Nikita Titov's avatar
Nikita Titov committed
753
754
    BuildGPUKernels();
    SetupKernelArguments();
755
756
757
  }
}

758
759
760
761
762
763
764
void GPUTreeLearner::BeforeTrain() {
  #if GPU_DEBUG >= 2
  printf("Copying intial full gradients and hessians to device\n");
  #endif
  // Copy initial full hessians and gradients to GPU.
  // We start copying as early as possible, instead of at ConstructHistogram().
  if (!use_bagging_ && num_dense_feature_groups_) {
765
    if (!share_state_->is_constant_hessian) {
766
      hessians_future_ = queue_.enqueue_write_buffer_async(device_hessians_, 0, num_data_ * sizeof(score_t), hessians_);
767
    } else {
768
      // setup hessian parameters only
769
      score_t const_hessian = hessians_[0];
770
771
772
773
774
775
776
      for (int i = 0; i <= kMaxLogWorkgroupsPerFeature; ++i) {
        // hessian is passed as a parameter
        histogram_kernels_[i].set_arg(6, const_hessian);
        histogram_allfeats_kernels_[i].set_arg(6, const_hessian);
        histogram_fulldata_kernels_[i].set_arg(6, const_hessian);
      }
    }
777
    gradients_future_ = queue_.enqueue_write_buffer_async(device_gradients_, 0, num_data_ * sizeof(score_t), gradients_);
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
  }

  SerialTreeLearner::BeforeTrain();

  // use bagging
  if (data_partition_->leaf_count(0) != num_data_ && num_dense_feature_groups_) {
    // On GPU, we start copying indices, gradients and hessians now, instead at ConstructHistogram()
    // copy used gradients and hessians to ordered buffer
    const data_size_t* indices = data_partition_->indices();
    data_size_t cnt = data_partition_->leaf_count(0);
    #if GPU_DEBUG > 0
    printf("Using bagging, examples count = %d\n", cnt);
    #endif
    // transfer the indices to GPU
    indices_future_ = boost::compute::copy_async(indices, indices + cnt, device_data_indices_->begin(), queue_);
793
    if (!share_state_->is_constant_hessian) {
794
795
796
797
798
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < cnt; ++i) {
        ordered_hessians_[i] = hessians_[indices[i]];
      }
      // transfer hessian to GPU
799
      hessians_future_ = queue_.enqueue_write_buffer_async(device_hessians_, 0, cnt * sizeof(score_t), ordered_hessians_.data());
800
    } else {
801
      // setup hessian parameters only
802
      score_t const_hessian = hessians_[indices[0]];
803
804
805
806
807
808
809
810
811
812
813
814
      for (int i = 0; i <= kMaxLogWorkgroupsPerFeature; ++i) {
        // hessian is passed as a parameter
        histogram_kernels_[i].set_arg(6, const_hessian);
        histogram_allfeats_kernels_[i].set_arg(6, const_hessian);
        histogram_fulldata_kernels_[i].set_arg(6, const_hessian);
      }
    }
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < cnt; ++i) {
      ordered_gradients_[i] = gradients_[indices[i]];
    }
    // transfer gradients to GPU
815
    gradients_future_ = queue_.enqueue_write_buffer_async(device_gradients_, 0, cnt * sizeof(score_t), ordered_gradients_.data());
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
  }
}

bool GPUTreeLearner::BeforeFindBestSplit(const Tree* tree, int left_leaf, int right_leaf) {
  int smaller_leaf;
  data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
  data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
  // only have root
  if (right_leaf < 0) {
    smaller_leaf = -1;
  } else if (num_data_in_left_child < num_data_in_right_child) {
    smaller_leaf = left_leaf;
  } else {
    smaller_leaf = right_leaf;
  }

  // Copy indices, gradients and hessians as early as possible
  if (smaller_leaf >= 0 && num_dense_feature_groups_) {
    // only need to initialize for smaller leaf
    // Get leaf boundary
    const data_size_t* indices = data_partition_->indices();
    data_size_t begin = data_partition_->leaf_begin(smaller_leaf);
    data_size_t end = begin + data_partition_->leaf_count(smaller_leaf);

    // copy indices to the GPU:
    #if GPU_DEBUG >= 2
    Log::Info("Copying indices, gradients and hessians to GPU...");
843
    printf("Indices size %d being copied (left = %d, right = %d)\n", end - begin, num_data_in_left_child, num_data_in_right_child);
844
845
846
    #endif
    indices_future_ = boost::compute::copy_async(indices + begin, indices + end, device_data_indices_->begin(), queue_);

847
    if (!share_state_->is_constant_hessian) {
848
849
850
851
852
      #pragma omp parallel for schedule(static)
      for (data_size_t i = begin; i < end; ++i) {
        ordered_hessians_[i - begin] = hessians_[indices[i]];
      }
      // copy ordered hessians to the GPU:
853
      hessians_future_ = queue_.enqueue_write_buffer_async(device_hessians_, 0, (end - begin) * sizeof(score_t), ptr_pinned_hessians_);
854
855
856
857
858
859
860
    }

    #pragma omp parallel for schedule(static)
    for (data_size_t i = begin; i < end; ++i) {
      ordered_gradients_[i - begin] = gradients_[indices[i]];
    }
    // copy ordered gradients to the GPU:
861
    gradients_future_ = queue_.enqueue_write_buffer_async(device_gradients_, 0, (end - begin) * sizeof(score_t), ptr_pinned_gradients_);
862
863

    #if GPU_DEBUG >= 2
864
    Log::Info("Gradients/hessians/indices copied to device with size %d", end - begin);
865
866
867
868
869
870
871
872
    #endif
  }
  return SerialTreeLearner::BeforeFindBestSplit(tree, left_leaf, right_leaf);
}

bool GPUTreeLearner::ConstructGPUHistogramsAsync(
  const std::vector<int8_t>& is_feature_used,
  const data_size_t* data_indices, data_size_t num_data,
873
874
  const score_t* gradients, const score_t* hessians,
  score_t* ordered_gradients, score_t* ordered_hessians) {
875
876
877
878
879
880
881
  if (num_data <= 0) {
    return false;
  }
  // do nothing if no features can be processed on GPU
  if (!num_dense_feature_groups_) {
    return false;
  }
882

883
884
885
886
887
888
889
890
891
892
893
  // copy data indices if it is not null
  if (data_indices != nullptr && num_data != num_data_) {
    indices_future_ = boost::compute::copy_async(data_indices, data_indices + num_data, device_data_indices_->begin(), queue_);
  }
  // generate and copy ordered_gradients if gradients is not null
  if (gradients != nullptr) {
    if (num_data != num_data_) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        ordered_gradients[i] = gradients[data_indices[i]];
      }
894
      gradients_future_ = queue_.enqueue_write_buffer_async(device_gradients_, 0, num_data * sizeof(score_t), ptr_pinned_gradients_);
895
    } else {
896
      gradients_future_ = queue_.enqueue_write_buffer_async(device_gradients_, 0, num_data * sizeof(score_t), gradients);
897
898
899
    }
  }
  // generate and copy ordered_hessians if hessians is not null
900
  if (hessians != nullptr && !share_state_->is_constant_hessian) {
901
902
903
904
905
    if (num_data != num_data_) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data; ++i) {
        ordered_hessians[i] = hessians[data_indices[i]];
      }
906
      hessians_future_ = queue_.enqueue_write_buffer_async(device_hessians_, 0, num_data * sizeof(score_t), ptr_pinned_hessians_);
907
    } else {
908
      hessians_future_ = queue_.enqueue_write_buffer_async(device_hessians_, 0, num_data * sizeof(score_t), hessians);
909
910
911
912
    }
  }
  // converted indices in is_feature_used to feature-group indices
  std::vector<int8_t> is_feature_group_used(num_feature_groups_, 0);
913
  #pragma omp parallel for schedule(static, 1024) if (num_features_ >= 2048)
914
  for (int i = 0; i < num_features_; ++i) {
915
    if (is_feature_used[i]) {
916
917
918
919
920
      is_feature_group_used[train_data_->Feature2Group(i)] = 1;
    }
  }
  // construct the feature masks for dense feature-groups
  int used_dense_feature_groups = 0;
921
  #pragma omp parallel for schedule(static, 1024) reduction(+:used_dense_feature_groups) if (num_dense_feature_groups_ >= 2048)
922
923
924
925
  for (int i = 0; i < num_dense_feature_groups_; ++i) {
    if (is_feature_group_used[dense_feature_group_map_[i]]) {
      feature_masks_[i] = 1;
      ++used_dense_feature_groups;
926
    } else {
927
928
929
930
931
932
933
934
935
      feature_masks_[i] = 0;
    }
  }
  bool use_all_features = used_dense_feature_groups == num_dense_feature_groups_;
  // if no feature group is used, just return and do not use GPU
  if (used_dense_feature_groups == 0) {
    return false;
  }
#if GPU_DEBUG >= 1
936
  printf("Feature masks:\n");
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
  for (unsigned int i = 0; i < feature_masks_.size(); ++i) {
    printf("%d ", feature_masks_[i]);
  }
  printf("\n");
  printf("%d feature groups, %d used, %d\n", num_dense_feature_groups_, used_dense_feature_groups, use_all_features);
#endif
  // if not all feature groups are used, we need to transfer the feature mask to GPU
  // otherwise, we will use a specialized GPU kernel with all feature groups enabled
  if (!use_all_features) {
    queue_.enqueue_write_buffer(device_feature_masks_, 0, num_dense_feature4_ * dword_features_, ptr_pinned_feature_masks_);
  }
  // All data have been prepared, now run the GPU kernel
  GPUHistogram(num_data, use_all_features);
  return true;
}

void GPUTreeLearner::ConstructHistograms(const std::vector<int8_t>& is_feature_used, bool use_subtract) {
  std::vector<int8_t> is_sparse_feature_used(num_features_, 0);
  std::vector<int8_t> is_dense_feature_used(num_features_, 0);
  #pragma omp parallel for schedule(static)
  for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
958
    if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
959
    if (!is_feature_used[feature_index]) continue;
960
    if (train_data_->IsMultiGroup(train_data_->Feature2Group(feature_index))) {
961
      is_sparse_feature_used[feature_index] = 1;
962
    } else {
963
964
965
966
      is_dense_feature_used[feature_index] = 1;
    }
  }
  // construct smaller leaf
967
  hist_t* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - kHistOffset;
968
969
970
971
972
973
974
  // ConstructGPUHistogramsAsync will return true if there are availabe feature gourps dispatched to GPU
  bool is_gpu_used = ConstructGPUHistogramsAsync(is_feature_used,
    nullptr, smaller_leaf_splits_->num_data_in_leaf(),
    nullptr, nullptr,
    nullptr, nullptr);
  // then construct sparse features on CPU
  train_data_->ConstructHistograms(is_sparse_feature_used,
975
976
    smaller_leaf_splits_->data_indices(), smaller_leaf_splits_->num_data_in_leaf(),
    gradients_, hessians_,
977
978
    ordered_gradients_.data(), ordered_hessians_.data(),
    share_state_.get(),
979
980
981
    ptr_smaller_leaf_hist_data);
  // wait for GPU to finish, only if GPU is actually used
  if (is_gpu_used) {
Guolin Ke's avatar
Guolin Ke committed
982
    if (config_->gpu_use_dp) {
983
      // use double precision
984
      WaitAndGetHistograms<hist_t>(ptr_smaller_leaf_hist_data);
985
    } else {
986
      // use single precision
987
      WaitAndGetHistograms<gpu_hist_t>(ptr_smaller_leaf_hist_data);
988
989
990
991
992
993
994
995
996
997
998
    }
  }

  // Compare GPU histogram with CPU histogram, useful for debuggin GPU code problem
  // #define GPU_DEBUG_COMPARE
  #ifdef GPU_DEBUG_COMPARE
  for (int i = 0; i < num_dense_feature_groups_; ++i) {
    if (!feature_masks_[i])
      continue;
    int dense_feature_group_index = dense_feature_group_map_[i];
    size_t size = train_data_->FeatureGroupNumBin(dense_feature_group_index);
999
    hist_t* ptr_smaller_leaf_hist_data = smaller_leaf_histogram_array_[0].RawData() - kHistOffset;
1000
1001
    hist_t* current_histogram = ptr_smaller_leaf_hist_data + train_data_->GroupBinBoundary(dense_feature_group_index) * 2;
    hist_t* gpu_histogram = new hist_t[size * 2];
1002
1003
    data_size_t num_data = smaller_leaf_splits_->num_data_in_leaf();
    printf("Comparing histogram for feature %d size %d, %lu bins\n", dense_feature_group_index, num_data, size);
1004
1005
    std::copy(current_histogram, current_histogram + size * 2, gpu_histogram);
    std::memset(current_histogram, 0, size * sizeof(hist_t) * 2);
1006
1007
1008
1009
    if (train_data_->FeatureGroupBin(dense_feature_group_index) == nullptr) {
      continue;
    }
    if (num_data != num_data_) {
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
      train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
        smaller_leaf_splits_->data_indices(),
        0,
        num_data,
        ordered_gradients_.data(),
        ordered_hessians_.data(),
        current_histogram);
    } else {
      train_data_->FeatureGroupBin(dense_feature_group_index)->ConstructHistogram(
        0,
        num_data,
        gradients_,
        hessians_,
        current_histogram);
    }
1025
    CompareHistograms(gpu_histogram, current_histogram, size, dense_feature_group_index);
1026
    std::copy(gpu_histogram, gpu_histogram + size * 2, current_histogram);
1027
1028
1029
1030
1031
1032
    delete [] gpu_histogram;
  }
  #endif

  if (larger_leaf_histogram_array_ != nullptr && !use_subtract) {
    // construct larger leaf
1033
    hist_t* ptr_larger_leaf_hist_data = larger_leaf_histogram_array_[0].RawData() - kHistOffset;
1034
1035
1036
1037
1038
1039
    is_gpu_used = ConstructGPUHistogramsAsync(is_feature_used,
      larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf(),
      gradients_, hessians_,
      ordered_gradients_.data(), ordered_hessians_.data());
    // then construct sparse features on CPU
    train_data_->ConstructHistograms(is_sparse_feature_used,
1040
1041
      larger_leaf_splits_->data_indices(), larger_leaf_splits_->num_data_in_leaf(),
      gradients_, hessians_,
1042
1043
      ordered_gradients_.data(), ordered_hessians_.data(),
      share_state_.get(),
1044
1045
1046
      ptr_larger_leaf_hist_data);
    // wait for GPU to finish, only if GPU is actually used
    if (is_gpu_used) {
Guolin Ke's avatar
Guolin Ke committed
1047
      if (config_->gpu_use_dp) {
1048
        // use double precision
1049
        WaitAndGetHistograms<hist_t>(ptr_larger_leaf_hist_data);
1050
      } else {
1051
        // use single precision
1052
        WaitAndGetHistograms<gpu_hist_t>(ptr_larger_leaf_hist_data);
1053
1054
1055
1056
1057
      }
    }
  }
}

1058
1059
void GPUTreeLearner::FindBestSplits(const Tree* tree) {
  SerialTreeLearner::FindBestSplits(tree);
1060
1061
1062

#if GPU_DEBUG >= 3
  for (int feature_index = 0; feature_index < num_features_; ++feature_index) {
1063
    if (!col_sampler_.is_feature_used_bytree()[feature_index]) continue;
1064
1065
1066
1067
1068
    if (parent_leaf_histogram_array_ != nullptr
        && !parent_leaf_histogram_array_[feature_index].is_splittable()) {
      smaller_leaf_histogram_array_[feature_index].set_is_splittable(false);
      continue;
    }
1069
    size_t bin_size = train_data_->FeatureNumBin(feature_index) + 1;
1070
    printf("Feature %d smaller leaf:\n", feature_index);
1071
1072
    PrintHistograms(smaller_leaf_histogram_array_[feature_index].RawData() - 1, bin_size);
    if (larger_leaf_splits_ == nullptr || larger_leaf_splits_->LeafIndex() < 0) { continue; }
1073
    printf("Feature %d larger leaf:\n", feature_index);
1074
1075
1076
1077
1078
1079
1080
1081
    PrintHistograms(larger_leaf_histogram_array_[feature_index].RawData() - 1, bin_size);
  }
#endif
}

/*!
 * \brief Split a leaf via SerialTreeLearner::Split, then sanity-check leaf sizes (single machine only).
 *
 * The row counts recorded in the split info must match the sizes of smaller_leaf_splits_
 * and larger_leaf_splits_. A mismatch signals a GPU histogram bug and is reported with
 * Log::Fatal. When the right child is not strictly larger, the leaf splits are
 * (re)initialized so that the smaller side refers to the right child.
 */
void GPUTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right_leaf) {
  const SplitInfo& best_split_info = best_split_per_leaf_[best_Leaf];
#if GPU_DEBUG >= 2
  printf("Splitting leaf %d with feature %d thresh %d gain %f stat %f %f %f %f\n", best_Leaf, best_split_info.feature, best_split_info.threshold, best_split_info.gain, best_split_info.left_sum_gradient, best_split_info.right_sum_gradient, best_split_info.left_sum_hessian, best_split_info.right_sum_hessian);
#endif
  SerialTreeLearner::Split(tree, best_Leaf, left_leaf, right_leaf);
  if (Network::num_machines() == 1) {
    // do some sanity check for the GPU algorithm
    if (best_split_info.left_count < best_split_info.right_count) {
      if ((best_split_info.left_count != smaller_leaf_splits_->num_data_in_leaf()) ||
          (best_split_info.right_count!= larger_leaf_splits_->num_data_in_leaf())) {
        Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf());
      }
    } else {
      // right child is the smaller (or equal) one: point the smaller leaf splits at it
      smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian, best_split_info.right_output);
      larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian, best_split_info.left_output);
      if ((best_split_info.left_count != larger_leaf_splits_->num_data_in_leaf()) ||
          (best_split_info.right_count!= smaller_leaf_splits_->num_data_in_leaf())) {
        Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf());
      }
    }
  }
}

}   // namespace LightGBM
1104
#endif  // USE_GPU